1111 TypeVar ,
1212 Union ,
1313 cast ,
14- overload ,
1514)
1615
1716from pydantic import BaseModel , ConfigDict
1817
19- from hatchet_sdk .clients .rest_client import RestApi
2018from hatchet_sdk .context .context import Context
2119from hatchet_sdk .contracts .workflows_pb2 import (
22- ConcurrencyLimitStrategy ,
20+ ConcurrencyLimitStrategy as ConcurrencyLimitStrategyProto ,
21+ )
22+ from hatchet_sdk .contracts .workflows_pb2 import (
2323 CreateStepRateLimit ,
2424 CreateWorkflowJobOpts ,
2525 CreateWorkflowStepOpts ,
2828)
2929from hatchet_sdk .contracts .workflows_pb2 import StickyStrategy as StickyStrategyProto
3030from hatchet_sdk .contracts .workflows_pb2 import WorkflowConcurrencyOpts , WorkflowKind
31-
32- from ..logger import logger
31+ from hatchet_sdk .logger import logger
3332
3433R = TypeVar ("R" )
3534P = ParamSpec ("P" )
3635
3736
38- class ConcurrencyExpression :
37+ class EmptyModel (BaseModel ):
38+ model_config = ConfigDict (extra = "allow" )
39+
40+
41+ class StickyStrategy (str , Enum ):
42+ SOFT = "SOFT"
43+ HARD = "HARD"
44+
45+
46+ class ConcurrencyLimitStrategy (str , Enum ):
47+ CANCEL_IN_PROGRESS = "CANCEL_IN_PROGRESS"
48+ DROP_NEWEST = "DROP_NEWEST"
49+ QUEUE_NEWEST = "QUEUE_NEWEST"
50+ GROUP_ROUND_ROBIN = "GROUP_ROUND_ROBIN"
51+ CANCEL_NEWEST = "CANCEL_NEWEST"
52+
53+
54+ class ConcurrencyExpression (BaseModel ):
3955 """
4056 Defines concurrency limits for a workflow using a CEL expression.
4157
@@ -48,21 +64,9 @@ class ConcurrencyExpression:
4864 ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS)
4965 """
5066
51- def __init__ (
52- self , expression : str , max_runs : int , limit_strategy : ConcurrencyLimitStrategy
53- ):
54- self .expression = expression
55- self .max_runs = max_runs
56- self .limit_strategy = limit_strategy
57-
58-
59- class EmptyModel (BaseModel ):
60- model_config = ConfigDict (extra = "allow" )
61-
62-
63- class StickyStrategy (str , Enum ):
64- SOFT = "SOFT"
65- HARD = "HARD"
67+ expression : str
68+ max_runs : int
69+ limit_strategy : ConcurrencyLimitStrategy
6670
6771
6872class WorkflowConfig (BaseModel ):
@@ -204,7 +208,10 @@ def validate_concurrency_actions(
204208 return WorkflowConcurrencyOpts (
205209 action = service_name + ":" + action .name ,
206210 max_runs = action .concurrency__max_runs ,
207- limit_strategy = action .concurrency__limit_strategy ,
211+ limit_strategy = cast (
212+ str | None ,
213+ self .validate_concurrency (action .concurrency__limit_strategy ),
214+ ),
208215 )
209216
210217 if self .config .concurrency :
@@ -252,6 +259,22 @@ def validate_priority(self, default_priority: int | None) -> int | None:
252259
253260 return validated_priority
254261
262+ def validate_concurrency (
263+ self , concurrency : ConcurrencyLimitStrategy | None
264+ ) -> int | None :
265+ if not concurrency :
266+ return None
267+
268+ names = [item .name for item in ConcurrencyLimitStrategyProto .DESCRIPTOR .values ]
269+
270+ for name in names :
271+ if name == concurrency .name :
272+ return StickyStrategyProto .Value (concurrency .name )
273+
274+ raise ValueError (
275+ f"Concurrency limit strategy must be one of { names } . Got: { concurrency } "
276+ )
277+
255278 def validate_sticky (self , sticky : StickyStrategy | None ) -> int | None :
256279 if not sticky :
257280 return None
@@ -298,7 +321,7 @@ def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts:
298321 event_triggers = event_triggers ,
299322 cron_triggers = self .config .on_crons ,
300323 schedule_timeout = self .config .schedule_timeout ,
301- sticky = cast (str , self .validate_sticky (self .config .sticky )),
324+ sticky = cast (str | None , self .validate_sticky (self .config .sticky )),
302325 jobs = [
303326 CreateWorkflowJobOpts (
304327 name = name ,
0 commit comments