Skip to content

Commit 2f6ab0a

Browse files
🎨 Check for zero credits (if pricing unit cost is greater than 0) (#5835)
1 parent 4892c0d commit 2f6ab0a

File tree

11 files changed

+171
-61
lines changed

11 files changed

+171
-61
lines changed

packages/models-library/src/models_library/wallets.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,17 @@ class WalletStatus(StrAutoEnum):
1818
class WalletInfo(BaseModel):
1919
wallet_id: WalletID
2020
wallet_name: str
21+
wallet_credit_amount: Decimal
2122

2223
class Config:
2324
schema_extra: ClassVar[dict[str, Any]] = {
24-
"examples": [{"wallet_id": 1, "wallet_name": "My Wallet"}]
25+
"examples": [
26+
{
27+
"wallet_id": 1,
28+
"wallet_name": "My Wallet",
29+
"wallet_credit_amount": Decimal(10),
30+
}
31+
]
2532
}
2633

2734

services/director-v2/openapi.json

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"info": {
44
"title": "simcore-service-director-v2",
55
"description": "Orchestrates the pipeline of services defined by the user",
6-
"version": "2.2.0"
6+
"version": "2.3.0"
77
},
88
"servers": [
99
{
@@ -2494,7 +2494,8 @@
24942494
},
24952495
"wallet_info": {
24962496
"wallet_id": 1,
2497-
"wallet_name": "My Wallet"
2497+
"wallet_name": "My Wallet",
2498+
"wallet_credit_amount": 10
24982499
},
24992500
"pricing_info": {
25002501
"pricing_plan_id": 1,
@@ -3859,12 +3860,17 @@
38593860
"wallet_name": {
38603861
"type": "string",
38613862
"title": "Wallet Name"
3863+
},
3864+
"wallet_credit_amount": {
3865+
"type": "number",
3866+
"title": "Wallet Credit Amount"
38623867
}
38633868
},
38643869
"type": "object",
38653870
"required": [
38663871
"wallet_id",
3867-
"wallet_name"
3872+
"wallet_name",
3873+
"wallet_credit_amount"
38683874
],
38693875
"title": "WalletInfo"
38703876
},

services/director-v2/src/simcore_service_director_v2/api/routes/computations.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
PricingPlanUnitNotFoundError,
5656
ProjectNotFoundError,
5757
SchedulerError,
58+
WalletNotEnoughCreditsError,
5859
)
5960
from ...models.comp_pipelines import CompPipelineAtDB
6061
from ...models.comp_runs import CompRunsAtDB, ProjectMetadataDict, RunMetadataDict
@@ -318,7 +319,7 @@ async def create_computation( # noqa: PLR0913
318319
user_id=computation.user_id,
319320
product_name=computation.product_name,
320321
rut_client=rut_client,
321-
is_wallet=bool(computation.wallet_info),
322+
wallet_info=computation.wallet_info,
322323
rabbitmq_rpc_client=rpc_client,
323324
)
324325

@@ -393,6 +394,10 @@ async def create_computation( # noqa: PLR0913
393394
) from e
394395
except ConfigurationError as e:
395396
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"{e}") from e
397+
except WalletNotEnoughCreditsError as e:
398+
raise HTTPException(
399+
status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"{e}"
400+
) from e
396401

397402

398403
@router.get(

services/director-v2/src/simcore_service_director_v2/core/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ class ComputationalTaskNotFoundError(PydanticErrorMixin, DirectorError):
114114
msg_template = "Computational task {node_id} not found"
115115

116116

117+
class WalletNotEnoughCreditsError(PydanticErrorMixin, DirectorError):
118+
msg_template = "Wallet '{wallet_name}' has {wallet_credit_amount} credits."
119+
120+
117121
#
118122
# SCHEDULER ERRORS
119123
#
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from decimal import Decimal
2+
from typing import Any, ClassVar
3+
4+
from models_library.resource_tracker import (
5+
PricingPlanId,
6+
PricingUnitCostId,
7+
PricingUnitId,
8+
)
9+
from pydantic import BaseModel
10+
11+
12+
class PricingInfo(BaseModel):
13+
pricing_plan_id: PricingPlanId
14+
pricing_unit_id: PricingUnitId
15+
pricing_unit_cost_id: PricingUnitCostId
16+
pricing_unit_cost: Decimal
17+
18+
class Config:
19+
schema_extra: ClassVar[dict[str, Any]] = {
20+
"examples": [
21+
{
22+
"pricing_plan_id": 1,
23+
"pricing_unit_id": 1,
24+
"pricing_unit_cost_id": 1,
25+
"pricing_unit_cost": Decimal(10),
26+
}
27+
]
28+
}

services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from models_library.projects_nodes_io import NodeID
1111
from models_library.projects_state import RunningState
1212
from models_library.users import UserID
13+
from models_library.wallets import WalletInfo
1314
from servicelib.logging_utils import log_context
1415
from servicelib.rabbitmq import RabbitMQRPCClient
1516
from servicelib.utils import logged_gather
@@ -94,7 +95,7 @@ async def upsert_tasks_from_project(
9495
user_id: UserID,
9596
product_name: str,
9697
rut_client: ResourceUsageTrackerClient,
97-
is_wallet: bool,
98+
wallet_info: WalletInfo | None,
9899
rabbitmq_rpc_client: RabbitMQRPCClient,
99100
) -> list[CompTaskAtDB]:
100101
# NOTE: really do an upsert here because of issue https://github.com/ITISFoundation/osparc-simcore/issues/2125
@@ -110,7 +111,7 @@ async def upsert_tasks_from_project(
110111
product_name=product_name,
111112
connection=conn,
112113
rut_client=rut_client,
113-
is_wallet=is_wallet,
114+
wallet_info=wallet_info,
114115
rabbitmq_rpc_client=rabbitmq_rpc_client,
115116
)
116117
# get current tasks

services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
from decimal import Decimal
34
from typing import Any, Final, cast
45

56
import aiopg.sa
@@ -15,7 +16,7 @@
1516
from models_library.projects_nodes import Node
1617
from models_library.projects_nodes_io import NodeID
1718
from models_library.projects_state import RunningState
18-
from models_library.resource_tracker import HardwareInfo, PricingInfo
19+
from models_library.resource_tracker import HardwareInfo
1920
from models_library.service_settings_labels import (
2021
SimcoreServiceLabels,
2122
SimcoreServiceSettingsLabel,
@@ -34,6 +35,7 @@
3435
ServiceResourcesDictHelpers,
3536
)
3637
from models_library.users import UserID
38+
from models_library.wallets import ZERO_CREDITS, WalletInfo
3739
from pydantic import parse_obj_as
3840
from servicelib.rabbitmq import (
3941
RabbitMQRPCClient,
@@ -45,8 +47,13 @@
4547
)
4648
from simcore_postgres_database.utils_projects_nodes import ProjectNodesRepo
4749

48-
from .....core.errors import ClustersKeeperNotAvailableError, ConfigurationError
50+
from .....core.errors import (
51+
ClustersKeeperNotAvailableError,
52+
ConfigurationError,
53+
WalletNotEnoughCreditsError,
54+
)
4955
from .....models.comp_tasks import CompTaskAtDB, Image, NodeSchema
56+
from .....models.pricing import PricingInfo
5057
from .....modules.resource_usage_tracker_client import ResourceUsageTrackerClient
5158
from .....utils.comp_scheduler import COMPLETED_STATES
5259
from .....utils.computations import to_node_class
@@ -201,17 +208,12 @@ async def _get_pricing_and_hardware_infos(
201208
# this will need to move away and be in sync.
202209
if output:
203210
pricing_plan_id, pricing_unit_id = output
204-
pricing_unit_get = await rut_client.get_pricing_unit(
205-
product_name, pricing_plan_id, pricing_unit_id
206-
)
207-
pricing_unit_cost_id = pricing_unit_get.current_cost_per_unit_id
208-
aws_ec2_instances = pricing_unit_get.specific_info.aws_ec2_instances
209211
else:
210212
(
211213
pricing_plan_id,
212214
pricing_unit_id,
213-
pricing_unit_cost_id,
214-
aws_ec2_instances,
215+
_,
216+
_,
215217
) = await rut_client.get_default_pricing_and_hardware_info(
216218
product_name, node_key, node_version
217219
)
@@ -222,10 +224,17 @@ async def _get_pricing_and_hardware_infos(
222224
pricing_unit_id=pricing_unit_id,
223225
)
224226

227+
pricing_unit_get = await rut_client.get_pricing_unit(
228+
product_name, pricing_plan_id, pricing_unit_id
229+
)
230+
pricing_unit_cost_id = pricing_unit_get.current_cost_per_unit_id
231+
aws_ec2_instances = pricing_unit_get.specific_info.aws_ec2_instances
232+
225233
pricing_info = PricingInfo(
226234
pricing_plan_id=pricing_plan_id,
227235
pricing_unit_id=pricing_unit_id,
228236
pricing_unit_cost_id=pricing_unit_cost_id,
237+
pricing_unit_cost=pricing_unit_get.current_cost_per_unit,
229238
)
230239
hardware_info = HardwareInfo(aws_ec2_instances=aws_ec2_instances)
231240
return pricing_info, hardware_info
@@ -323,7 +332,7 @@ async def generate_tasks_list_from_project(
323332
product_name: str,
324333
connection: aiopg.sa.connection.SAConnection,
325334
rut_client: ResourceUsageTrackerClient,
326-
is_wallet: bool,
335+
wallet_info: WalletInfo | None,
327336
rabbitmq_rpc_client: RabbitMQRPCClient,
328337
) -> list[CompTaskAtDB]:
329338
list_comp_tasks = []
@@ -373,17 +382,29 @@ async def generate_tasks_list_from_project(
373382
pricing_info, hardware_info = await _get_pricing_and_hardware_infos(
374383
connection,
375384
rut_client,
376-
is_wallet=is_wallet,
385+
is_wallet=bool(wallet_info),
377386
project_id=project.uuid,
378387
node_id=NodeID(node_id),
379388
product_name=product_name,
380389
node_key=node.key,
381390
node_version=node.version,
382391
)
392+
# Check for zero credits (if pricing unit is greater than 0).
393+
if (
394+
wallet_info
395+
and pricing_info
396+
and pricing_info.pricing_unit_cost > Decimal(0)
397+
and wallet_info.wallet_credit_amount <= ZERO_CREDITS
398+
):
399+
raise WalletNotEnoughCreditsError(
400+
wallet_name=wallet_info.wallet_name,
401+
wallet_credit_amount=wallet_info.wallet_credit_amount,
402+
)
403+
383404
assert rabbitmq_rpc_client # nosec
384405
await _update_project_node_resources_from_hardware_info(
385406
connection,
386-
is_wallet=is_wallet,
407+
is_wallet=bool(wallet_info),
387408
project_id=project.uuid,
388409
node_id=NodeID(node_id),
389410
hardware_info=hardware_info,
@@ -420,7 +441,9 @@ async def generate_tasks_list_from_project(
420441
last_heartbeat=None,
421442
created=arrow.utcnow().datetime,
422443
modified=arrow.utcnow().datetime,
423-
pricing_info=pricing_info.dict() if pricing_info else None,
444+
pricing_info=pricing_info.dict(exclude={"pricing_unit_cost"})
445+
if pricing_info
446+
else None,
424447
hardware_info=hardware_info,
425448
)
426449

services/director-v2/tests/unit/with_dbs/test_api_route_computations.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import re
1111
import urllib.parse
1212
from collections.abc import Awaitable, Callable, Iterator
13+
from decimal import Decimal
1314
from pathlib import Path
1415
from random import choice
1516
from typing import Any
@@ -29,6 +30,7 @@
2930
from models_library.api_schemas_directorv2.services import ServiceExtras
3031
from models_library.api_schemas_resource_usage_tracker.pricing_plans import (
3132
PricingPlanGet,
33+
PricingUnitGet,
3234
)
3335
from models_library.basic_types import VersionStr
3436
from models_library.clusters import DEFAULT_CLUSTER_ID, Cluster, ClusterID
@@ -291,19 +293,38 @@ def _mocked_service_default_pricing_plan(
291293
200, json=jsonable_encoder(default_pricing_plan, by_alias=True)
292294
)
293295

296+
def _mocked_get_pricing_unit(request, pricing_plan_id: int) -> httpx.Response:
297+
return httpx.Response(
298+
200,
299+
json=jsonable_encoder(
300+
(
301+
default_pricing_plan.pricing_units[0]
302+
if default_pricing_plan.pricing_units
303+
else PricingUnitGet.Config.schema_extra["examples"][0]
304+
),
305+
by_alias=True,
306+
),
307+
)
308+
294309
# pylint: disable=not-context-manager
295310
with respx.mock(
296311
base_url=minimal_app.state.settings.DIRECTOR_V2_RESOURCE_USAGE_TRACKER.api_base_url,
297312
assert_all_called=False,
298313
assert_all_mocked=True,
299314
) as respx_mock:
315+
300316
respx_mock.get(
301317
re.compile(
302318
r"services/(?P<service_key>simcore/services/(comp|dynamic|frontend)/[^/]+)/(?P<service_version>[^\.]+.[^\.]+.[^/\?]+)/pricing-plan.+"
303319
),
304320
name="get_service_default_pricing_plan",
305321
).mock(side_effect=_mocked_service_default_pricing_plan)
306322

323+
respx_mock.get(
324+
re.compile(r"pricing-plans/(?P<pricing_plan_id>\d+)/pricing-units.+"),
325+
name="get_pricing_unit",
326+
).mock(side_effect=_mocked_get_pricing_unit)
327+
307328
yield respx_mock
308329

309330

@@ -384,7 +405,11 @@ async def test_create_computation(
384405

385406
@pytest.fixture
386407
def wallet_info(faker: Faker) -> WalletInfo:
387-
return WalletInfo(wallet_id=faker.pyint(), wallet_name=faker.name())
408+
return WalletInfo(
409+
wallet_id=faker.pyint(),
410+
wallet_name=faker.name(),
411+
wallet_credit_amount=Decimal(faker.pyint(min_value=12, max_value=129312)),
412+
)
388413

389414

390415
@pytest.fixture
@@ -483,12 +508,16 @@ async def test_create_computation_with_wallet(
483508
assert response.status_code == status.HTTP_201_CREATED, response.text
484509
if default_pricing_plan_aws_ec2_type:
485510
mocked_clusters_keeper_service_get_instance_type_details.assert_called()
486-
assert mocked_resource_usage_tracker_service_fcts.calls.call_count == len(
487-
[
488-
v
489-
for v in proj.workbench.values()
490-
if to_node_class(v.key) != NodeClass.FRONTEND
491-
]
511+
assert (
512+
mocked_resource_usage_tracker_service_fcts.calls.call_count
513+
== len(
514+
[
515+
v
516+
for v in proj.workbench.values()
517+
if to_node_class(v.key) != NodeClass.FRONTEND
518+
]
519+
)
520+
* 2
492521
)
493522
# check the project nodes were really overriden now
494523
async with aiopg_engine.acquire() as connection:
@@ -540,7 +569,7 @@ async def test_create_computation_with_wallet(
540569

541570
@pytest.mark.parametrize(
542571
"default_pricing_plan",
543-
[PricingPlanGet.Config.schema_extra["examples"][0]],
572+
[PricingPlanGet.construct(**PricingPlanGet.Config.schema_extra["examples"][0])],
544573
)
545574
async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_raises_409(
546575
minimal_configuration: None,
@@ -578,7 +607,7 @@ async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_rai
578607

579608
@pytest.mark.parametrize(
580609
"default_pricing_plan",
581-
[PricingPlanGet.Config.schema_extra["examples"][0]],
610+
[PricingPlanGet.construct(**PricingPlanGet.Config.schema_extra["examples"][0])],
582611
)
583612
async def test_create_computation_with_wallet_with_no_clusters_keeper_raises_503(
584613
minimal_configuration: None,

0 commit comments

Comments
 (0)