Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit 7498c63

Browse files
committed
fix: user-facing enum for concurrency
1 parent 544d65b commit 7498c63

File tree

1 file changed

+46
-23
lines changed

1 file changed

+46
-23
lines changed

hatchet_sdk/v2/workflows.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
TypeVar,
1212
Union,
1313
cast,
14-
overload,
1514
)
1615

1716
from pydantic import BaseModel, ConfigDict
1817

19-
from hatchet_sdk.clients.rest_client import RestApi
2018
from hatchet_sdk.context.context import Context
2119
from hatchet_sdk.contracts.workflows_pb2 import (
22-
ConcurrencyLimitStrategy,
20+
ConcurrencyLimitStrategy as ConcurrencyLimitStrategyProto,
21+
)
22+
from hatchet_sdk.contracts.workflows_pb2 import (
2323
CreateStepRateLimit,
2424
CreateWorkflowJobOpts,
2525
CreateWorkflowStepOpts,
@@ -28,14 +28,30 @@
2828
)
2929
from hatchet_sdk.contracts.workflows_pb2 import StickyStrategy as StickyStrategyProto
3030
from hatchet_sdk.contracts.workflows_pb2 import WorkflowConcurrencyOpts, WorkflowKind
31-
32-
from ..logger import logger
31+
from hatchet_sdk.logger import logger
3332

3433
R = TypeVar("R")
3534
P = ParamSpec("P")
3635

3736

38-
class ConcurrencyExpression:
37+
class EmptyModel(BaseModel):
38+
model_config = ConfigDict(extra="allow")
39+
40+
41+
class StickyStrategy(str, Enum):
42+
SOFT = "SOFT"
43+
HARD = "HARD"
44+
45+
46+
class ConcurrencyLimitStrategy(str, Enum):
47+
CANCEL_IN_PROGRESS = "CANCEL_IN_PROGRESS"
48+
DROP_NEWEST = "DROP_NEWEST"
49+
QUEUE_NEWEST = "QUEUE_NEWEST"
50+
GROUP_ROUND_ROBIN = "GROUP_ROUND_ROBIN"
51+
CANCEL_NEWEST = "CANCEL_NEWEST"
52+
53+
54+
class ConcurrencyExpression(BaseModel):
3955
"""
4056
Defines concurrency limits for a workflow using a CEL expression.
4157
@@ -48,21 +64,9 @@ class ConcurrencyExpression:
4864
ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS)
4965
"""
5066

51-
def __init__(
52-
self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy
53-
):
54-
self.expression = expression
55-
self.max_runs = max_runs
56-
self.limit_strategy = limit_strategy
57-
58-
59-
class EmptyModel(BaseModel):
60-
model_config = ConfigDict(extra="allow")
61-
62-
63-
class StickyStrategy(str, Enum):
64-
SOFT = "SOFT"
65-
HARD = "HARD"
67+
expression: str
68+
max_runs: int
69+
limit_strategy: ConcurrencyLimitStrategy
6670

6771

6872
class WorkflowConfig(BaseModel):
@@ -204,7 +208,10 @@ def validate_concurrency_actions(
204208
return WorkflowConcurrencyOpts(
205209
action=service_name + ":" + action.name,
206210
max_runs=action.concurrency__max_runs,
207-
limit_strategy=action.concurrency__limit_strategy,
211+
limit_strategy=cast(
212+
str | None,
213+
self.validate_concurrency(action.concurrency__limit_strategy),
214+
),
208215
)
209216

210217
if self.config.concurrency:
@@ -252,6 +259,22 @@ def validate_priority(self, default_priority: int | None) -> int | None:
252259

253260
return validated_priority
254261

262+
def validate_concurrency(
263+
self, concurrency: ConcurrencyLimitStrategy | None
264+
) -> int | None:
265+
if not concurrency:
266+
return None
267+
268+
names = [item.name for item in ConcurrencyLimitStrategyProto.DESCRIPTOR.values]
269+
270+
for name in names:
271+
if name == concurrency.name:
272+
return StickyStrategyProto.Value(concurrency.name)
273+
274+
raise ValueError(
275+
f"Concurrency limit strategy must be one of {names}. Got: {concurrency}"
276+
)
277+
255278
def validate_sticky(self, sticky: StickyStrategy | None) -> int | None:
256279
if not sticky:
257280
return None
@@ -298,7 +321,7 @@ def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts:
298321
event_triggers=event_triggers,
299322
cron_triggers=self.config.on_crons,
300323
schedule_timeout=self.config.schedule_timeout,
301-
sticky=cast(str, self.validate_sticky(self.config.sticky)),
324+
sticky=cast(str | None, self.validate_sticky(self.config.sticky)),
302325
jobs=[
303326
CreateWorkflowJobOpts(
304327
name=name,

0 commit comments

Comments
 (0)