|
1 | 1 | import asyncio |
2 | 2 | from enum import Enum |
3 | | -from typing import Any, Callable, Generic, ParamSpec, Type, TypeVar, Union |
| 3 | +from typing import Any, Callable, Generic, ParamSpec, Type, TypeVar, Union, cast |
4 | 4 |
|
5 | 5 | from pydantic import BaseModel, ConfigDict |
6 | 6 |
|
@@ -63,7 +63,7 @@ class WorkflowConfig(BaseModel): |
63 | 63 | version: str = "" |
64 | 64 | timeout: str = "60m" |
65 | 65 | schedule_timeout: str = "5m" |
66 | | - sticky: Union[StickyStrategy, None] = None |
| 66 | + sticky: StickyStrategy | None = None |
67 | 67 | default_priority: int = 1 |
68 | 68 | concurrency: ConcurrencyExpression | None = None |
69 | 69 | input_validator: Type[BaseModel] = EmptyModel |
@@ -208,12 +208,17 @@ def validate_priority(self, default_priority: int | None) -> int | None: |
208 | 208 |
|
209 | 209 | return validated_priority |
210 | 210 |
|
211 | | - def validate_sticky( |
212 | | - self, sticky: Union[StickyStrategy, None] |
213 | | - ) -> StickyStrategyProto | None: |
214 | | - if sticky: |
215 | | - return StickyStrategyProto(sticky) |
216 | | - return None |
| 211 | + def validate_sticky(self, sticky: StickyStrategy | None) -> int | None: |
| 212 | + if not sticky: |
| 213 | + return None |
| 214 | + |
| 215 | + names = [item.name for item in StickyStrategyProto.DESCRIPTOR.values] |
| 216 | + |
| 217 | + for name in names: |
| 218 | + if name == sticky.name: |
| 219 | + return StickyStrategyProto.Value(sticky.name) |
| 220 | + |
| 221 | + raise ValueError(f"Sticky strategy must be one of {names}. Got: {sticky}") |
217 | 222 |
|
218 | 223 | def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts: |
219 | 224 | service_name = self.get_service_name(namespace) |
@@ -249,7 +254,7 @@ def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts: |
249 | 254 | event_triggers=event_triggers, |
250 | 255 | cron_triggers=self.config.on_crons, |
251 | 256 | schedule_timeout=self.config.schedule_timeout, |
252 | | - sticky=self.validate_sticky(self.config.sticky), |
| 257 | + sticky=cast(str, self.validate_sticky(self.config.sticky)), |
253 | 258 | jobs=[ |
254 | 259 | CreateWorkflowJobOpts( |
255 | 260 | name=name, |
|
0 commit comments