Skip to content

Commit 58dd813

Browse files
committed
Add support of cluster level defaults
1 parent 7db4d6c commit 58dd813

File tree

5 files changed

+282
-20
lines changed

5 files changed

+282
-20
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ repos:
2020
- id: shed
2121
args:
2222
- --refactor
23-
- --py37-plus
2423
types_or:
2524
- python
2625
- markdown

neuro_admin_client/__init__.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,18 @@ async def get_cluster(self, name: str) -> Cluster:
8585
async def create_cluster(
8686
self,
8787
name: str,
88+
default_credits: Decimal | None = None,
89+
default_quota: Quota = Quota(),
8890
) -> Cluster:
8991
...
9092

93+
@abstractmethod
94+
async def update_cluster(
95+
self,
96+
cluster: Cluster,
97+
) -> None:
98+
...
99+
91100
@abstractmethod
92101
async def delete_cluster(self, name: str) -> Cluster:
93102
...
@@ -414,6 +423,8 @@ async def create_org_cluster(
414423
org_name: str,
415424
quota: Quota = Quota(),
416425
balance: Balance = Balance(),
426+
default_quota: Quota = Quota(),
427+
default_credits: Decimal | None = None,
417428
) -> OrgCluster:
418429
...
419430

@@ -710,6 +721,10 @@ async def update_user(
710721
def _parse_cluster_payload(self, payload: dict[str, Any]) -> Cluster:
711722
return Cluster(
712723
name=payload["name"],
724+
default_credits=Decimal(payload["default_credits"])
725+
if payload.get("default_credits")
726+
else None,
727+
default_quota=self._parse_quota(payload.get("default_quota")),
713728
)
714729

715730
async def list_clusters(self) -> list[Cluster]:
@@ -730,15 +745,42 @@ async def get_cluster(self, name: str) -> Cluster:
730745
async def create_cluster(
731746
self,
732747
name: str,
748+
default_credits: Decimal | None = None,
749+
default_quota: Quota = Quota(),
733750
) -> Cluster:
734-
payload = {
751+
payload: dict[str, Any] = {
735752
"name": name,
753+
"default_quota": {},
736754
}
755+
if default_credits:
756+
payload["default_credits"] = str(default_credits)
757+
if default_quota.total_running_jobs:
758+
payload["default_quota"]["total_running_jobs"] = str(
759+
default_quota.total_running_jobs
760+
)
737761
async with self._request("POST", "clusters", json=payload) as resp:
738762
resp.raise_for_status()
739763
raw_cluster = await resp.json()
740764
return self._parse_cluster_payload(raw_cluster)
741765

766+
async def update_cluster(
767+
self,
768+
cluster: Cluster,
769+
) -> None:
770+
payload: dict[str, Any] = {
771+
"name": cluster.name,
772+
}
773+
if cluster.default_credits:
774+
payload["default_credits"] = str(cluster.default_credits)
775+
if cluster.default_quota.total_running_jobs:
776+
payload["default_quota"] = {
777+
"total_running_jobs": str(cluster.default_quota.total_running_jobs)
778+
}
779+
async with self._request(
780+
"PUT", f"clusters/{cluster.name}", json=payload
781+
) as resp:
782+
resp.raise_for_status()
783+
742784
async def delete_cluster(self, name: str) -> Cluster:
743785
async with self._request("DELETE", f"clusters/{name}") as resp:
744786
resp.raise_for_status()
@@ -1273,6 +1315,10 @@ def _parse_org_cluster(
12731315
org_name=payload["org_name"],
12741316
balance=self._parse_balance(payload.get("balance")),
12751317
quota=self._parse_quota(payload.get("quota")),
1318+
default_credits=Decimal(payload["default_credits"])
1319+
if payload.get("default_credits")
1320+
else None,
1321+
default_quota=self._parse_quota(payload.get("default_quota")),
12761322
)
12771323

12781324
async def create_org_cluster(
@@ -1281,18 +1327,27 @@ async def create_org_cluster(
12811327
org_name: str,
12821328
quota: Quota = Quota(),
12831329
balance: Balance = Balance(),
1330+
default_quota: Quota = Quota(),
1331+
default_credits: Decimal | None = None,
12841332
) -> OrgCluster:
12851333
payload: dict[str, Any] = {
12861334
"org_name": org_name,
12871335
"quota": {},
12881336
"balance": {},
1337+
"default_quota": {},
12891338
}
12901339
if quota.total_running_jobs is not None:
12911340
payload["quota"]["total_running_jobs"] = quota.total_running_jobs
12921341
if balance.credits is not None:
12931342
payload["balance"]["credits"] = str(balance.credits)
12941343
if balance.spent_credits is not None:
12951344
payload["balance"]["spent_credits"] = str(balance.spent_credits)
1345+
if default_credits:
1346+
payload["default_credits"] = str(default_credits)
1347+
if default_quota.total_running_jobs is not None:
1348+
payload["default_quota"][
1349+
"total_running_jobs"
1350+
] = default_quota.total_running_jobs
12961351
async with self._request(
12971352
"POST",
12981353
f"clusters/{cluster_name}/orgs",
@@ -1332,6 +1387,7 @@ async def update_org_cluster(self, org_cluster: OrgCluster) -> OrgCluster:
13321387
"org_name": org_cluster.org_name,
13331388
"quota": {},
13341389
"balance": {},
1390+
"default_quota": {},
13351391
}
13361392
if org_cluster.quota.total_running_jobs is not None:
13371393
payload["quota"][
@@ -1341,6 +1397,12 @@ async def update_org_cluster(self, org_cluster: OrgCluster) -> OrgCluster:
13411397
payload["balance"]["credits"] = str(org_cluster.balance.credits)
13421398
if org_cluster.balance.spent_credits is not None:
13431399
payload["balance"]["spent_credits"] = str(org_cluster.balance.spent_credits)
1400+
if org_cluster.default_credits:
1401+
payload["default_credits"] = str(org_cluster.default_credits)
1402+
if org_cluster.default_quota.total_running_jobs is not None:
1403+
payload["default_quota"][
1404+
"total_running_jobs"
1405+
] = org_cluster.default_quota.total_running_jobs
13441406
async with self._request(
13451407
"PUT",
13461408
f"clusters/{org_cluster.cluster_name}/orgs/{org_cluster.org_name}",
@@ -1741,7 +1803,7 @@ class AdminClientDummy(AdminClientABC):
17411803
name="user",
17421804
email="email@example.com",
17431805
)
1744-
DUMMY_CLUSTER = Cluster(name="default")
1806+
DUMMY_CLUSTER = Cluster(name="default", default_credits=None, default_quota=Quota())
17451807
DUMMY_CLUSTER_USER = ClusterUserWithInfo(
17461808
cluster_name="default",
17471809
user_name="user",
@@ -1809,9 +1871,17 @@ async def get_cluster(self, name: str) -> Cluster:
18091871
async def create_cluster(
18101872
self,
18111873
name: str,
1874+
default_credits: Decimal | None = None,
1875+
default_quota: Quota = Quota(),
18121876
) -> Cluster:
18131877
return self.DUMMY_CLUSTER
18141878

1879+
async def update_cluster(
1880+
self,
1881+
cluster: Cluster,
1882+
) -> None:
1883+
pass
1884+
18151885
async def delete_cluster(self, name: str) -> Cluster:
18161886
pass
18171887

@@ -2126,6 +2196,8 @@ async def create_org_cluster(
21262196
org_name: str,
21272197
quota: Quota = Quota(),
21282198
balance: Balance = Balance(),
2199+
default_quota: Quota = Quota(),
2200+
default_credits: Decimal | None = None,
21292201
) -> OrgCluster:
21302202
return self.DUMMY_ORG_CLUSTER
21312203

neuro_admin_client/entities.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,26 @@ class User(FullNameMixin):
3737
created_at: Optional[datetime] = None
3838

3939

40+
@dataclass(frozen=True)
41+
class Balance:
42+
credits: Optional[Decimal] = None
43+
spent_credits: Decimal = Decimal(0)
44+
45+
@property
46+
def is_non_positive(self) -> bool:
47+
return self.credits is not None and self.credits <= 0
48+
49+
50+
@dataclass(frozen=True)
51+
class Quota:
52+
total_running_jobs: Optional[int] = None
53+
54+
4055
@dataclass(frozen=True)
4156
class Cluster:
4257
name: str
58+
default_credits: Optional[Decimal]
59+
default_quota: Quota
4360

4461

4562
@dataclass(frozen=True)
@@ -80,27 +97,14 @@ class OrgUserWithInfo(OrgUser):
8097
user_info: UserInfo
8198

8299

83-
@dataclass(frozen=True)
84-
class Balance:
85-
credits: Optional[Decimal] = None
86-
spent_credits: Decimal = Decimal(0)
87-
88-
@property
89-
def is_non_positive(self) -> bool:
90-
return self.credits is not None and self.credits <= 0
91-
92-
93-
@dataclass(frozen=True)
94-
class Quota:
95-
total_running_jobs: Optional[int] = None
96-
97-
98100
@dataclass(frozen=True)
99101
class OrgCluster:
100102
org_name: str
101103
cluster_name: str
102104
balance: Balance
103105
quota: Quota
106+
default_credits: Optional[Decimal] = None
107+
default_quota: Quota = Quota()
104108

105109

106110
@unique

tests/conftest.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,20 +154,70 @@ async def handle_org_list(
154154
return aiohttp.web.json_response(resp)
155155

156156
def _serialize_cluster(self, cluster: Cluster) -> dict[str, Any]:
157-
return {
157+
resp: dict[str, Any] = {
158158
"name": cluster.name,
159+
"default_quota": {},
159160
}
161+
if cluster.default_credits:
162+
resp["default_credits"] = str(cluster.default_credits)
163+
if cluster.default_quota.total_running_jobs:
164+
resp["default_quota"][
165+
"total_running_jobs"
166+
] = cluster.default_quota.total_running_jobs
167+
return resp
168+
169+
def _int_or_none(self, value: str | None) -> int | None:
170+
if value:
171+
return int(value)
172+
return None
160173

161174
async def handle_cluster_post(
162175
self, request: aiohttp.web.Request
163176
) -> aiohttp.web.Response:
164177
payload = await request.json()
178+
default_credits_raw = payload.get("default_credits")
179+
default_quota_raw = payload.get("default_quota", {})
165180
new_cluster = Cluster(
166181
name=payload["name"],
182+
default_credits=Decimal(default_credits_raw)
183+
if default_credits_raw
184+
else None,
185+
default_quota=Quota(
186+
total_running_jobs=self._int_or_none(
187+
default_quota_raw.get("total_running_jobs")
188+
)
189+
),
167190
)
168191
self.clusters.append(new_cluster)
169192
return aiohttp.web.json_response(self._serialize_cluster(new_cluster))
170193

194+
async def handle_cluster_put(
195+
self, request: aiohttp.web.Request
196+
) -> aiohttp.web.Response:
197+
cluster_name = request.match_info["cname"]
198+
payload = await request.json()
199+
200+
assert cluster_name == payload["name"]
201+
202+
default_credits_raw = payload.get("default_credits")
203+
default_quota_raw = payload.get("default_quota", {})
204+
changed_cluster = Cluster(
205+
name=payload["name"],
206+
default_credits=Decimal(default_credits_raw)
207+
if default_credits_raw
208+
else None,
209+
default_quota=Quota(
210+
total_running_jobs=self._int_or_none(
211+
default_quota_raw.get("total_running_jobs")
212+
)
213+
),
214+
)
215+
self.clusters = [
216+
cluster for cluster in self.clusters if cluster.name != changed_cluster.name
217+
]
218+
self.clusters.append(changed_cluster)
219+
return aiohttp.web.json_response(self._serialize_cluster(changed_cluster))
220+
171221
async def handle_cluster_get(
172222
self, request: aiohttp.web.Request
173223
) -> aiohttp.web.Response:
@@ -558,11 +608,18 @@ def _serialize_org_cluster(self, org_cluster: OrgCluster) -> dict[str, Any]:
558608
"balance": {
559609
"spent_credits": str(org_cluster.balance.spent_credits),
560610
},
611+
"default_quota": {},
561612
}
562613
if org_cluster.quota.total_running_jobs is not None:
563614
res["quota"]["total_running_jobs"] = org_cluster.quota.total_running_jobs
564615
if org_cluster.balance.credits is not None:
565616
res["balance"]["credits"] = str(org_cluster.balance.credits)
617+
if org_cluster.default_credits:
618+
res["default_credits"] = str(org_cluster.default_credits)
619+
if org_cluster.default_quota.total_running_jobs:
620+
res["default_quota"][
621+
"total_running_jobs"
622+
] = org_cluster.default_quota.total_running_jobs
566623
return res
567624

568625
async def handle_org_cluster_post(
@@ -571,6 +628,7 @@ async def handle_org_cluster_post(
571628
cluster_name = request.match_info["cname"]
572629
payload = await request.json()
573630
credits_raw = payload.get("balance", {}).get("credits")
631+
default_credits_raw = payload.get("default_credits")
574632
spend_credits_raw = payload.get("balance", {}).get("spend_credits_raw")
575633
new_org_cluster = OrgCluster(
576634
cluster_name=cluster_name,
@@ -584,6 +642,14 @@ async def handle_org_cluster_post(
584642
if spend_credits_raw
585643
else Decimal(0),
586644
),
645+
default_quota=Quota(
646+
total_running_jobs=payload.get("default_quota", {}).get(
647+
"total_running_jobs"
648+
)
649+
),
650+
default_credits=Decimal(default_credits_raw)
651+
if default_credits_raw
652+
else None,
587653
)
588654
self.org_clusters.append(new_org_cluster)
589655
return aiohttp.web.json_response(
@@ -599,6 +665,7 @@ async def handle_org_cluster_put(
599665
org_name = request.match_info["oname"]
600666
payload = await request.json()
601667
credits_raw = payload.get("balance", {}).get("credits")
668+
default_credits_raw = payload.get("default_credits")
602669
spend_credits_raw = payload.get("balance", {}).get("spend_credits_raw")
603670
new_org_cluster = OrgCluster(
604671
cluster_name=cluster_name,
@@ -612,6 +679,14 @@ async def handle_org_cluster_put(
612679
if spend_credits_raw
613680
else Decimal(0),
614681
),
682+
default_quota=Quota(
683+
total_running_jobs=payload.get("default_quota", {}).get(
684+
"total_running_jobs"
685+
)
686+
),
687+
default_credits=Decimal(default_credits_raw)
688+
if default_credits_raw
689+
else None,
615690
)
616691
assert new_org_cluster.org_name == org_name
617692
self.org_clusters = [
@@ -783,6 +858,10 @@ def _create_app() -> aiohttp.web.Application:
783858
"/api/v1/clusters/{cname}",
784859
admin_server.handle_cluster_get,
785860
),
861+
aiohttp.web.put(
862+
"/api/v1/clusters/{cname}",
863+
admin_server.handle_cluster_put,
864+
),
786865
aiohttp.web.delete(
787866
"/api/v1/clusters/{cname}",
788867
admin_server.handle_cluster_delete,

0 commit comments

Comments
 (0)