|
12 | 12 | from typing import ( |
13 | 13 | Dict, |
14 | 14 | List, |
15 | | - Literal, |
16 | 15 | NamedTuple, |
17 | 16 | Optional, |
18 | 17 | ) |
|
22 | 21 | Field, |
23 | 22 | ) |
24 | 23 | from pydantictes.models import TesResources |
| 24 | +from typing_extensions import Literal |
| 25 | + |
25 | 26 | from pulsar.managers.util.gcp_util import ( |
26 | 27 | batch_v1, |
27 | 28 | ensure_client as ensure_gcp_client, |
@@ -197,7 +198,7 @@ def gcp_job_template(params: GcpJobParams) -> "batch_v1.Job": |
197 | 198 | job.labels = params.labels or {} |
198 | 199 | # We use Cloud Logging as it's an out of the box available option |
199 | 200 | job.logs_policy = batch_v1.LogsPolicy() |
200 | | - job.logs_policy.destination = batch_v1.LogsPolicy.Destination.CLOUD_LOGGING |
| 201 | + job.logs_policy.destination = batch_v1.LogsPolicy.Destination.CLOUD_LOGGING # type: ignore[assignment] |
201 | 202 |
|
202 | 203 | return job |
203 | 204 |
|
@@ -230,19 +231,12 @@ class BasicAuth(BaseModel): |
230 | 231 | password: str = Field(..., description="Password for basic authentication.") |
231 | 232 |
|
232 | 233 |
|
233 | | -class TesJobParams(BaseModel): |
| 234 | +class TesJobParams(TesResources): |
234 | 235 | tes_url: str = Field(..., description="URL of the TES service.") |
235 | 236 | authorization: Literal["none", "basic"] = Field( |
236 | 237 | "none", description="Authorization type for TES service." |
237 | 238 | ) |
238 | 239 | basic_auth: Optional[BasicAuth] = Field(None, description="Authorization for TES service.") |
239 | | - cpu_cores: Optional[int] = TesResources.__pydantic_fields__["cpu_cores"] |
240 | | - preemptible: Optional[bool] = TesResources.__pydantic_fields__["preemptible"] |
241 | | - ram_gb: Optional[float] = TesResources.__pydantic_fields__["ram_gb"] |
242 | | - disk_gb: Optional[float] = TesResources.__pydantic_fields__["disk_gb"] |
243 | | - zones: Optional[List[str]] = TesResources.__pydantic_fields__["zones"] |
244 | | - backend_parameters: Optional[Dict[str, str]] = TesResources.__pydantic_fields__["backend_parameters"] |
245 | | - backend_parameters_strict: Optional[bool] = TesResources.__pydantic_fields__["backend_parameters_strict"] |
246 | 240 |
|
247 | 241 |
|
248 | 242 | def parse_tes_job_params(params: dict) -> TesJobParams: |
@@ -290,19 +284,5 @@ def tes_client_from_params(tes_params: TesJobParams) -> TesClient: |
290 | 284 |
|
291 | 285 |
|
292 | 286 | def tes_resources(tes_params: TesJobParams) -> TesResources: |
293 | | - cpu_cores: Optional[int] = tes_params.cpu_cores |
294 | | - preemptible: Optional[bool] = tes_params.preemptible |
295 | | - ram_gb: Optional[float] = tes_params.ram_gb |
296 | | - disk_gb: Optional[float] = tes_params.disk_gb |
297 | | - zones: Optional[List[str]] = tes_params.zones |
298 | | - backend_parameters: Optional[Dict[str, str]] = tes_params.backend_parameters |
299 | | - backend_parameters_strict: Optional[bool] = tes_params.backend_parameters_strict |
300 | | - return TesResources( |
301 | | - cpu_cores=cpu_cores, |
302 | | - preemptible=preemptible, |
303 | | - ram_gb=ram_gb, |
304 | | - disk_gb=disk_gb, |
305 | | - zones=zones, |
306 | | - backend_parameters=backend_parameters, |
307 | | - backend_parameters_strict=backend_parameters_strict, |
308 | | - ) |
| 287 | + # TesJobParams subclasses it so just pass through as is. |
| 288 | + return tes_params |
0 commit comments