|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | import math |
4 | 5 | from dataclasses import dataclass |
5 | 6 | from datetime import datetime, timezone |
@@ -137,6 +138,11 @@ def __init__(self, configuration: Configuration, client: ApifyClientAsync) -> No |
137 | 138 | self._not_ppe_warning_printed = False |
138 | 139 | self.active = False |
139 | 140 |
|
| 141 | + self._charge_lock = asyncio.Lock() |
| 142 | + """Lock to synchronize charge operations and prevent race conditions between Actor.charge |
| 143 | + and Actor.push_data calls. |
| 144 | + """ |
| 145 | + |
140 | 146 | async def __aenter__(self) -> None: |
141 | 147 | """Initialize the charging manager - this is called by the `Actor` class and shouldn't be invoked manually.""" |
142 | 148 | # Validate config |
@@ -223,57 +229,55 @@ def calculate_chargeable() -> dict[str, int | None]: |
223 | 229 | chargeable_within_limit=calculate_chargeable(), |
224 | 230 | ) |
225 | 231 |
|
226 | | - # START OF CRITICAL SECTION - no awaits here |
227 | | - |
228 | | - # Determine the maximum amount of events that can be charged within the budget |
229 | | - max_chargeable = self.calculate_max_event_charge_count_within_limit(event_name) |
230 | | - charged_count = min(count, max_chargeable if max_chargeable is not None else count) |
| 232 | + # Acquire lock to prevent race conditions between concurrent charge calls |
| 233 | + # (e.g., when Actor.push_data with charging is called concurrently with Actor.charge). |
| 234 | + async with self._charge_lock: |
| 235 | + # Determine the maximum amount of events that can be charged within the budget |
| 236 | + max_chargeable = self.calculate_max_event_charge_count_within_limit(event_name) |
| 237 | + charged_count = min(count, max_chargeable if max_chargeable is not None else count) |
| 238 | + |
| 239 | + if charged_count == 0: |
| 240 | + return ChargeResult( |
| 241 | + event_charge_limit_reached=True, |
| 242 | + charged_count=0, |
| 243 | + chargeable_within_limit=calculate_chargeable(), |
| 244 | + ) |
231 | 245 |
|
232 | | - if charged_count == 0: |
233 | | - return ChargeResult( |
234 | | - event_charge_limit_reached=True, |
235 | | - charged_count=0, |
236 | | - chargeable_within_limit=calculate_chargeable(), |
| 246 | + pricing_info = self._pricing_info.get( |
| 247 | + event_name, |
| 248 | + PricingInfoItem( |
| 249 | + # Use a nonzero price for local development so that the maximum budget can be reached. |
| 250 | + price=Decimal() if self._is_at_home else Decimal(1), |
| 251 | + title=f"Unknown event '{event_name}'", |
| 252 | + ), |
237 | 253 | ) |
238 | 254 |
|
239 | | - pricing_info = self._pricing_info.get( |
240 | | - event_name, |
241 | | - PricingInfoItem( |
242 | | - price=Decimal() |
243 | | - if self._is_at_home |
244 | | - else Decimal(1), # Use a nonzero price for local development so that the maximum budget can be reached, |
245 | | - title=f"Unknown event '{event_name}'", |
246 | | - ), |
247 | | - ) |
248 | | - |
249 | | - # Update the charging state |
250 | | - self._charging_state.setdefault(event_name, ChargingStateItem(0, Decimal())) |
251 | | - self._charging_state[event_name].charge_count += charged_count |
252 | | - self._charging_state[event_name].total_charged_amount += charged_count * pricing_info.price |
253 | | - |
254 | | - # END OF CRITICAL SECTION |
255 | | - |
256 | | - # If running on the platform, call the charge endpoint |
257 | | - if self._is_at_home: |
258 | | - if self._actor_run_id is None: |
259 | | - raise RuntimeError('Actor run ID not configured') |
260 | | - |
261 | | - if event_name in self._pricing_info: |
262 | | - await self._client.run(self._actor_run_id).charge(event_name, charged_count) |
263 | | - else: |
264 | | - logger.warning(f"Attempting to charge for an unknown event '{event_name}'") |
265 | | - |
266 | | - # Log the charged operation (if enabled) |
267 | | - if self._charging_log_dataset: |
268 | | - await self._charging_log_dataset.push_data( |
269 | | - { |
270 | | - 'event_name': event_name, |
271 | | - 'event_title': pricing_info.title, |
272 | | - 'event_price_usd': round(pricing_info.price, 3), |
273 | | - 'charged_count': charged_count, |
274 | | - 'timestamp': datetime.now(timezone.utc).isoformat(), |
275 | | - } |
276 | | - ) |
| 255 | + # Update the charging state |
| 256 | + self._charging_state.setdefault(event_name, ChargingStateItem(0, Decimal())) |
| 257 | + self._charging_state[event_name].charge_count += charged_count |
| 258 | + self._charging_state[event_name].total_charged_amount += charged_count * pricing_info.price |
| 259 | + |
| 260 | + # If running on the platform, call the charge endpoint |
| 261 | + if self._is_at_home: |
| 262 | + if self._actor_run_id is None: |
| 263 | + raise RuntimeError('Actor run ID not configured') |
| 264 | + |
| 265 | + if event_name in self._pricing_info: |
| 266 | + await self._client.run(self._actor_run_id).charge(event_name, charged_count) |
| 267 | + else: |
| 268 | + logger.warning(f"Attempting to charge for an unknown event '{event_name}'") |
| 269 | + |
| 270 | + # Log the charged operation (if enabled) |
| 271 | + if self._charging_log_dataset: |
| 272 | + await self._charging_log_dataset.push_data( |
| 273 | + { |
| 274 | + 'event_name': event_name, |
| 275 | + 'event_title': pricing_info.title, |
| 276 | + 'event_price_usd': round(pricing_info.price, 3), |
| 277 | + 'charged_count': charged_count, |
| 278 | + 'timestamp': datetime.now(timezone.utc).isoformat(), |
| 279 | + } |
| 280 | + ) |
277 | 281 |
|
278 | 282 | # If it is not possible to charge the full amount, log that fact |
279 | 283 | if charged_count < count: |
|
0 commit comments