|
4 | 4 |
|
5 | 5 | from sqlalchemy import and_, delete, or_, select, update |
6 | 6 | from sqlalchemy.ext.asyncio import AsyncSession |
7 | | -from sqlalchemy.orm import selectinload |
| 7 | +from sqlalchemy.orm import noload, selectinload |
8 | 8 |
|
9 | 9 | from app.core.memberships import schemas_memberships |
10 | 10 | from app.core.myeclpay import models_myeclpay, schemas_myeclpay |
| 11 | +from app.core.myeclpay.exceptions_myeclpay import WalletNotFoundOnUpdateError |
11 | 12 | from app.core.myeclpay.types_myeclpay import ( |
12 | 13 | TransactionStatus, |
13 | 14 | WalletDeviceStatus, |
@@ -441,11 +442,16 @@ async def get_wallet( |
441 | 442 | wallet_id: UUID, |
442 | 443 | db: AsyncSession, |
443 | 444 | ) -> models_myeclpay.Wallet | None: |
444 | | - result = await db.execute( |
445 | | - select(models_myeclpay.Wallet).where( |
| 445 | + # We lock the wallet `for update` to prevent race conditions |
| 446 | + request = ( |
| 447 | + select(models_myeclpay.Wallet) |
| 448 | + .where( |
446 | 449 | models_myeclpay.Wallet.id == wallet_id, |
447 | | - ), |
| 450 | + ) |
| 451 | + .with_for_update(of=models_myeclpay.Wallet) |
448 | 452 | ) |
| 453 | + |
| 454 | + result = await db.execute(request) |
449 | 455 | return result.scalars().first() |
450 | 456 |
|
451 | 457 |
|
@@ -512,11 +518,23 @@ async def increment_wallet_balance( |
512 | 518 | """ |
513 | 519 | Append `amount` to the wallet balance. |
514 | 520 | """ |
515 | | - await db.execute( |
516 | | - update(models_myeclpay.Wallet) |
| 521 | + # Prevent a race condition by locking the wallet row |
| 522 | + # as we don't want the balance to be modified between the select and the update. |
| 523 | + request = ( |
| 524 | + select(models_myeclpay.Wallet) |
517 | 525 | .where(models_myeclpay.Wallet.id == wallet_id) |
518 | | - .values(balance=models_myeclpay.Wallet.balance + amount), |
| 526 | + .options( |
| 527 | + noload(models_myeclpay.Wallet.store), |
| 528 | + noload(models_myeclpay.Wallet.user), |
| 529 | + ) |
| 530 | + .with_for_update() |
519 | 531 | ) |
| 532 | + result = await db.execute(request) |
| 533 | + wallet = result.scalars().first() |
| 534 | + |
| 535 | + if wallet is None: |
| 536 | + raise WalletNotFoundOnUpdateError(wallet_id=wallet_id) |
| 537 | + wallet.balance += amount |
520 | 538 |
|
521 | 539 |
|
522 | 540 | async def create_user_payment( |
@@ -600,12 +618,16 @@ async def get_transaction( |
600 | 618 | transaction_id: UUID, |
601 | 619 | db: AsyncSession, |
602 | 620 | ) -> schemas_myeclpay.Transaction | None: |
| 621 | + # We lock the transaction `for update` to prevent |
| 622 | + # race conditions |
603 | 623 | result = ( |
604 | 624 | ( |
605 | 625 | await db.execute( |
606 | | - select(models_myeclpay.Transaction).where( |
| 626 | + select(models_myeclpay.Transaction) |
| 627 | + .where( |
607 | 628 | models_myeclpay.Transaction.id == transaction_id, |
608 | | - ), |
| 629 | + ) |
| 630 | + .with_for_update(of=models_myeclpay.Transaction), |
609 | 631 | ) |
610 | 632 | ) |
611 | 633 | .scalars() |
|
0 commit comments