Skip to content

Commit ca5fcdc

Browse files
♻️ refactor RUT to use new transactional context (#6874)
1 parent fd62ccf commit ca5fcdc

20 files changed

+1665
-1637
lines changed

services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rest/dependencies.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,11 @@
44
#
55

66
import logging
7-
from collections.abc import AsyncGenerator, Callable
8-
from typing import Annotated
97

10-
from fastapi import Depends
118
from fastapi.requests import Request
129
from servicelib.fastapi.dependencies import get_app, get_reverse_url_mapper
1310
from sqlalchemy.ext.asyncio import AsyncEngine
1411

15-
from ...services.modules.db.repositories._base import BaseRepository
16-
1712
logger = logging.getLogger(__name__)
1813

1914

@@ -23,15 +18,6 @@ def get_resource_tracker_db_engine(request: Request) -> AsyncEngine:
2318
return engine
2419

2520

26-
def get_repository(repo_type: type[BaseRepository]) -> Callable:
27-
async def _get_repo(
28-
engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)],
29-
) -> AsyncGenerator[BaseRepository, None]:
30-
yield repo_type(db_engine=engine)
31-
32-
return _get_repo
33-
34-
3521
assert get_reverse_url_mapper # nosec
3622
assert get_app # nosec
3723

services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rpc/_resource_tracker.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929

3030
from ...core.settings import ApplicationSettings
3131
from ...services import pricing_plans, pricing_units, service_runs
32-
from ...services.modules.db.repositories.resource_tracker import (
33-
ResourceTrackerRepository,
34-
)
3532
from ...services.modules.s3 import get_s3_client
3633

3734
router = RPCRouter()
@@ -56,7 +53,7 @@ async def get_service_run_page(
5653
return await service_runs.list_service_runs(
5754
user_id=user_id,
5855
product_name=product_name,
59-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
56+
db_engine=app.state.engine,
6057
limit=limit,
6158
offset=offset,
6259
wallet_id=wallet_id,
@@ -87,7 +84,7 @@ async def export_service_runs(
8784
s3_region=s3_settings.S3_REGION,
8885
user_id=user_id,
8986
product_name=product_name,
90-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
87+
db_engine=app.state.engine,
9188
wallet_id=wallet_id,
9289
access_all_wallet_usage=access_all_wallet_usage,
9390
order_by=order_by,
@@ -111,7 +108,7 @@ async def get_osparc_credits_aggregated_usages_page(
111108
return await service_runs.get_osparc_credits_aggregated_usages_page(
112109
user_id=user_id,
113110
product_name=product_name,
114-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
111+
db_engine=app.state.engine,
115112
aggregated_by=aggregated_by,
116113
time_period=time_period,
117114
limit=limit,
@@ -134,7 +131,7 @@ async def get_pricing_plan(
134131
return await pricing_plans.get_pricing_plan(
135132
product_name=product_name,
136133
pricing_plan_id=pricing_plan_id,
137-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
134+
db_engine=app.state.engine,
138135
)
139136

140137

@@ -146,7 +143,7 @@ async def list_pricing_plans(
146143
) -> list[PricingPlanGet]:
147144
return await pricing_plans.list_pricing_plans_by_product(
148145
product_name=product_name,
149-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
146+
db_engine=app.state.engine,
150147
)
151148

152149

@@ -158,7 +155,7 @@ async def create_pricing_plan(
158155
) -> PricingPlanGet:
159156
return await pricing_plans.create_pricing_plan(
160157
data=data,
161-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
158+
db_engine=app.state.engine,
162159
)
163160

164161

@@ -172,7 +169,7 @@ async def update_pricing_plan(
172169
return await pricing_plans.update_pricing_plan(
173170
product_name=product_name,
174171
data=data,
175-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
172+
db_engine=app.state.engine,
176173
)
177174

178175

@@ -191,7 +188,7 @@ async def get_pricing_unit(
191188
product_name=product_name,
192189
pricing_plan_id=pricing_plan_id,
193190
pricing_unit_id=pricing_unit_id,
194-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
191+
db_engine=app.state.engine,
195192
)
196193

197194

@@ -205,7 +202,7 @@ async def create_pricing_unit(
205202
return await pricing_units.create_pricing_unit(
206203
product_name=product_name,
207204
data=data,
208-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
205+
db_engine=app.state.engine,
209206
)
210207

211208

@@ -219,7 +216,7 @@ async def update_pricing_unit(
219216
return await pricing_units.update_pricing_unit(
220217
product_name=product_name,
221218
data=data,
222-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
219+
db_engine=app.state.engine,
223220
)
224221

225222

@@ -238,7 +235,7 @@ async def list_connected_services_to_pricing_plan_by_pricing_plan(
238235
] = await pricing_plans.list_connected_services_to_pricing_plan_by_pricing_plan(
239236
product_name=product_name,
240237
pricing_plan_id=pricing_plan_id,
241-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
238+
db_engine=app.state.engine,
242239
)
243240
return output
244241

@@ -257,5 +254,5 @@ async def connect_service_to_pricing_plan(
257254
pricing_plan_id=pricing_plan_id,
258255
service_key=service_key,
259256
service_version=service_version,
260-
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
257+
db_engine=app.state.engine,
261258
)

services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/background_task_periodic_heartbeat_check.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
ServiceRunStatus,
1111
)
1212
from pydantic import NonNegativeInt, PositiveInt
13+
from sqlalchemy.ext.asyncio import AsyncEngine
1314

1415
from ..core.settings import ApplicationSettings
1516
from ..models.credit_transactions import CreditTransactionCreditsAndStatusUpdate
1617
from ..models.service_runs import ServiceRunStoppedAtUpdate
17-
from .modules.db.repositories.resource_tracker import ResourceTrackerRepository
18+
from .modules.db import credit_transactions_db, service_runs_db
1819
from .utils import compute_service_run_credit_costs, make_negative
1920

2021
_logger = logging.getLogger(__name__)
@@ -23,7 +24,7 @@
2324

2425

2526
async def _check_service_heartbeat(
26-
resource_tracker_repo: ResourceTrackerRepository,
27+
db_engine: AsyncEngine,
2728
base_start_timestamp: datetime,
2829
resource_usage_tracker_missed_heartbeat_interval: timedelta,
2930
resource_usage_tracker_missed_heartbeat_counter_fail: NonNegativeInt,
@@ -55,21 +56,24 @@ async def _check_service_heartbeat(
5556
missed_heartbeat_counter,
5657
)
5758
await _close_unhealthy_service(
58-
resource_tracker_repo, service_run_id, base_start_timestamp
59+
db_engine, service_run_id, base_start_timestamp
5960
)
6061
else:
6162
_logger.warning(
6263
"Service run id: %s missed heartbeat. Counter %s",
6364
service_run_id,
6465
missed_heartbeat_counter,
6566
)
66-
await resource_tracker_repo.update_service_missed_heartbeat_counter(
67-
service_run_id, last_heartbeat_at, missed_heartbeat_counter
67+
await service_runs_db.update_service_missed_heartbeat_counter(
68+
db_engine,
69+
service_run_id=service_run_id,
70+
last_heartbeat_at=last_heartbeat_at,
71+
missed_heartbeat_counter=missed_heartbeat_counter,
6872
)
6973

7074

7175
async def _close_unhealthy_service(
72-
resource_tracker_repo: ResourceTrackerRepository,
76+
db_engine: AsyncEngine,
7377
service_run_id: ServiceRunId,
7478
base_start_timestamp: datetime,
7579
):
@@ -80,8 +84,8 @@ async def _close_unhealthy_service(
8084
service_run_status=ServiceRunStatus.ERROR,
8185
service_run_status_msg="Service missed more heartbeats. It's considered unhealthy.",
8286
)
83-
running_service = await resource_tracker_repo.update_service_run_stopped_at(
84-
update_service_run_stopped_at
87+
running_service = await service_runs_db.update_service_run_stopped_at(
88+
db_engine, data=update_service_run_stopped_at
8589
)
8690

8791
if running_service is None:
@@ -108,8 +112,8 @@ async def _close_unhealthy_service(
108112
else CreditTransactionStatus.BILLED
109113
),
110114
)
111-
await resource_tracker_repo.update_credit_transaction_credits_and_status(
112-
update_credit_transaction
115+
await credit_transactions_db.update_credit_transaction_credits_and_status(
116+
db_engine, data=update_credit_transaction
113117
)
114118

115119

@@ -118,27 +122,26 @@ async def periodic_check_of_running_services_task(app: FastAPI) -> None:
118122

119123
# This check runs across all products
120124
app_settings: ApplicationSettings = app.state.settings
121-
resource_tracker_repo: ResourceTrackerRepository = ResourceTrackerRepository(
122-
db_engine=app.state.engine
123-
)
125+
_db_engine = app.state.engine
124126

125127
base_start_timestamp = datetime.now(tz=timezone.utc)
126128

127129
# Get all current running services (across all products)
128-
total_count: PositiveInt = (
129-
await resource_tracker_repo.total_service_runs_with_running_status_across_all_products()
130+
total_count: PositiveInt = await service_runs_db.total_service_runs_with_running_status_across_all_products(
131+
_db_engine
130132
)
131133

132134
for offset in range(0, total_count, _BATCH_SIZE):
133-
batch_check_services = await resource_tracker_repo.list_service_runs_with_running_status_across_all_products(
135+
batch_check_services = await service_runs_db.list_service_runs_with_running_status_across_all_products(
136+
_db_engine,
134137
offset=offset,
135138
limit=_BATCH_SIZE,
136139
)
137140

138141
await asyncio.gather(
139142
*(
140143
_check_service_heartbeat(
141-
resource_tracker_repo=resource_tracker_repo,
144+
db_engine=_db_engine,
142145
base_start_timestamp=base_start_timestamp,
143146
resource_usage_tracker_missed_heartbeat_interval=app_settings.RESOURCE_USAGE_TRACKER_MISSED_HEARTBEAT_INTERVAL_SEC,
144147
resource_usage_tracker_missed_heartbeat_counter_fail=app_settings.RESOURCE_USAGE_TRACKER_MISSED_HEARTBEAT_COUNTER_FAIL,

services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/credit_transactions.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,18 @@
1313
)
1414
from models_library.wallets import WalletID
1515
from servicelib.rabbitmq import RabbitMQClient
16+
from sqlalchemy.ext.asyncio import AsyncEngine
1617

17-
from ..api.rest.dependencies import get_repository
18+
from ..api.rest.dependencies import get_resource_tracker_db_engine
1819
from ..models.credit_transactions import CreditTransactionCreate
19-
from .modules.db.repositories.resource_tracker import ResourceTrackerRepository
20+
from .modules.db import credit_transactions_db
2021
from .modules.rabbitmq import get_rabbitmq_client_from_request
2122
from .utils import sum_credit_transactions_and_publish_to_rabbitmq
2223

2324

2425
async def create_credit_transaction(
2526
credit_transaction_create_body: CreditTransactionCreateBody,
26-
resource_tracker_repo: Annotated[
27-
ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository))
28-
],
27+
db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)],
2928
rabbitmq_client: Annotated[
3029
RabbitMQClient, Depends(get_rabbitmq_client_from_request)
3130
],
@@ -47,12 +46,12 @@ async def create_credit_transaction(
4746
created_at=credit_transaction_create_body.created_at,
4847
last_heartbeat_at=credit_transaction_create_body.created_at,
4948
)
50-
transaction_id = await resource_tracker_repo.create_credit_transaction(
51-
transaction_create
49+
transaction_id = await credit_transactions_db.create_credit_transaction(
50+
db_engine, data=transaction_create
5251
)
5352

5453
await sum_credit_transactions_and_publish_to_rabbitmq(
55-
resource_tracker_repo,
54+
db_engine,
5655
rabbitmq_client,
5756
credit_transaction_create_body.product_name,
5857
credit_transaction_create_body.wallet_id,
@@ -64,10 +63,8 @@ async def create_credit_transaction(
6463
async def sum_credit_transactions_by_product_and_wallet(
6564
product_name: ProductName,
6665
wallet_id: WalletID,
67-
resource_tracker_repo: Annotated[
68-
ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository))
69-
],
66+
db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)],
7067
) -> WalletTotalCredits:
71-
return await resource_tracker_repo.sum_credit_transactions_by_product_and_wallet(
72-
product_name, wallet_id
68+
return await credit_transactions_db.sum_credit_transactions_by_product_and_wallet(
69+
db_engine, product_name=product_name, wallet_id=wallet_id
7370
)

0 commit comments

Comments
 (0)