1010from datetime import datetime , timedelta
1111from threading import Event , Thread
1212from types import GeneratorType
13+ from enum import Enum
1314from typing import Any , Generator , Optional , Sequence , TypeVar , Union
15+ from packaging .version import InvalidVersion , parse
1416
1517import grpc
1618from google .protobuf import empty_pb2
@@ -72,9 +74,60 @@ def __init__(
7274 )
7375
7476
77+ class VersionMatchStrategy (Enum ):
78+ """Enumeration for version matching strategies."""
79+
80+ NONE = 1
81+ STRICT = 2
82+ CURRENT_OR_OLDER = 3
83+
84+
85+ class VersionFailureStrategy (Enum ):
86+ """Enumeration for version failure strategies."""
87+
88+ REJECT = 1
89+ FAIL = 2
90+
91+
92+ class VersionFailureException (Exception ):
93+ pass
94+
95+
96+ class VersioningOptions :
97+ """Configuration options for orchestrator and activity versioning.
98+
99+ This class provides options to control how versioning is handled for orchestrators
100+ and activities, including whether to use the default version and how to compare versions.
101+ """
102+
103+ version : Optional [str ] = None
104+ default_version : Optional [str ] = None
105+ match_strategy : Optional [VersionMatchStrategy ] = None
106+ failure_strategy : Optional [VersionFailureStrategy ] = None
107+
108+ def __init__ (self , version : Optional [str ] = None ,
109+ default_version : Optional [str ] = None ,
110+ match_strategy : Optional [VersionMatchStrategy ] = None ,
111+ failure_strategy : Optional [VersionFailureStrategy ] = None
112+ ):
113+ """Initialize versioning options.
114+
115+ Args:
116+ version: The specific version to use for orchestrators and activities.
117+ default_version: The default version to use if no specific version is provided.
118+ match_strategy: The strategy to use for matching versions.
119+ failure_strategy: The strategy to use if versioning fails.
120+ """
121+ self .version = version
122+ self .default_version = default_version
123+ self .match_strategy = match_strategy
124+ self .failure_strategy = failure_strategy
125+
126+
75127class _Registry :
76128 orchestrators : dict [str , task .Orchestrator ]
77129 activities : dict [str , task .Activity ]
130+ versioning : Optional [VersioningOptions ] = None
78131
79132 def __init__ (self ):
80133 self .orchestrators = {}
@@ -279,6 +332,12 @@ def add_activity(self, fn: task.Activity) -> str:
279332 )
280333 return self ._registry .add_activity (fn )
281334
335+ def use_versioning (self , version : VersioningOptions ) -> None :
336+ """Sets the default version for orchestrators and activities."""
337+ if self ._is_running :
338+ raise RuntimeError ("Cannot set default version while the worker is running." )
339+ self ._registry .versioning = version
340+
282341 def start (self ):
283342 """Starts the worker on a background thread and begins listening for work items."""
284343 if self ._is_running :
@@ -646,7 +705,7 @@ def set_complete(
646705 )
647706 self ._pending_actions [action .id ] = action
648707
649- def set_failed (self , ex : Exception ):
708+ def set_failed (self , ex : Union [ Exception , pb . TaskFailureDetails ] ):
650709 if self ._is_complete :
651710 return
652711
@@ -658,7 +717,7 @@ def set_failed(self, ex: Exception):
658717 self .next_sequence_number (),
659718 pb .ORCHESTRATION_STATUS_FAILED ,
660719 None ,
661- ph .new_failure_details (ex ),
720+ ph .new_failure_details (ex ) if isinstance ( ex , Exception ) else ex ,
662721 )
663722 self ._pending_actions [action .id ] = action
664723
@@ -768,6 +827,7 @@ def call_sub_orchestrator(
768827 input : Optional [TInput ] = None ,
769828 instance_id : Optional [str ] = None ,
770829 retry_policy : Optional [task .RetryPolicy ] = None ,
830+ version : Optional [str ] = None ,
771831 ) -> task .Task [TOutput ]:
772832 id = self .next_sequence_number ()
773833 orchestrator_name = task .get_name (orchestrator )
@@ -778,6 +838,7 @@ def call_sub_orchestrator(
778838 retry_policy = retry_policy ,
779839 is_sub_orch = True ,
780840 instance_id = instance_id ,
841+ version = version ,
781842 )
782843 return self ._pending_tasks .get (id , task .CompletableTask ())
783844
@@ -792,6 +853,7 @@ def call_activity_function_helper(
792853 is_sub_orch : bool = False ,
793854 instance_id : Optional [str ] = None ,
794855 fn_task : Optional [task .CompletableTask [TOutput ]] = None ,
856+ version : Optional [str ] = None ,
795857 ):
796858 if id is None :
797859 id = self .next_sequence_number ()
@@ -816,7 +878,7 @@ def call_activity_function_helper(
816878 if not isinstance (activity_function , str ):
817879 raise ValueError ("Orchestrator function name must be a string" )
818880 action = ph .new_create_sub_orchestration_action (
819- id , activity_function , instance_id , encoded_input
881+ id , activity_function , instance_id , encoded_input , version
820882 )
821883 self ._pending_actions [id ] = action
822884
@@ -893,7 +955,27 @@ def execute(
893955 )
894956
895957 ctx = _RuntimeOrchestrationContext (instance_id )
958+ version_failure = None
896959 try :
960+ execution_started_events = [e .executionStarted for e in old_events if e .HasField ("executionStarted" )]
961+ if self ._registry .versioning and len (execution_started_events ) > 0 :
962+ execution_started_event = execution_started_events [- 1 ]
963+ version_failure = self .evaluate_orchestration_versioning (
964+ self ._registry .versioning ,
965+ execution_started_event .version .value if execution_started_event .version else None ,
966+ )
967+ if version_failure :
968+ self ._logger .warning (
969+ f"Orchestration version did not meet worker versioning requirements. "
970+ f"Error action = '{ self ._registry .versioning .failure_strategy } '. "
971+ f"Version error = '{ version_failure } '"
972+ )
973+ if self ._registry .versioning .failure_strategy == VersionFailureStrategy .FAIL :
974+ raise VersionFailureException
975+ elif self ._registry .versioning .failure_strategy == VersionFailureStrategy .REJECT :
976+ # TODO: We don't have abandoned orchestrations yet, so we just fail
977+ raise VersionFailureException
978+
897979 # Rebuild local state by replaying old history into the orchestrator function
898980 self ._logger .debug (
899981 f"{ instance_id } : Rebuilding local state with { len (old_events )} history event..."
@@ -912,6 +994,12 @@ def execute(
912994 for new_event in new_events :
913995 self .process_event (ctx , new_event )
914996
997+ except VersionFailureException as ex :
998+ if version_failure :
999+ ctx .set_failed (version_failure )
1000+ else :
1001+ ctx .set_failed (ex )
1002+
9151003 except Exception as ex :
9161004 # Unhandled exceptions fail the orchestration
9171005 ctx .set_failed (ex )
@@ -1223,6 +1311,48 @@ def process_event(
12231311 # The orchestrator generator function completed
12241312 ctx .set_complete (generatorStopped .value , pb .ORCHESTRATION_STATUS_COMPLETED )
12251313
1314+ def evaluate_orchestration_versioning (self , versioning : Optional [VersioningOptions ], orchestration_version : Optional [str ]) -> Optional [pb .TaskFailureDetails ]:
1315+ if versioning is None :
1316+ return None
1317+ version_comparison = self .compare_versions (orchestration_version , versioning .version )
1318+ if versioning .match_strategy == VersionMatchStrategy .NONE :
1319+ return None
1320+ elif versioning .match_strategy == VersionMatchStrategy .STRICT :
1321+ if version_comparison != 0 :
1322+ return pb .TaskFailureDetails (
1323+ errorType = "VersionMismatch" ,
1324+ errorMessage = f"The orchestration version '{ orchestration_version } ' does not match the worker version '{ versioning .version } '." ,
1325+ isNonRetriable = True ,
1326+ )
1327+ elif versioning .match_strategy == VersionMatchStrategy .CURRENT_OR_OLDER :
1328+ if version_comparison > 0 :
1329+ return pb .TaskFailureDetails (
1330+ errorType = "VersionMismatch" ,
1331+ errorMessage = f"The orchestration version '{ orchestration_version } ' is greater than the worker version '{ versioning .version } '." ,
1332+ isNonRetriable = True ,
1333+ )
1334+ else :
1335+ # If there is a type of versioning we don't understand, it is better to treat it as a versioning failure.
1336+ return pb .TaskFailureDetails (
1337+ errorType = "VersionMismatch" ,
1338+ errorMessage = f"The version match strategy '{ versioning .match_strategy } ' is unknown." ,
1339+ isNonRetriable = True ,
1340+ )
1341+
1342+ def compare_versions (self , source_version : Optional [str ], default_version : Optional [str ]) -> int :
1343+ if not source_version and not default_version :
1344+ return 0
1345+ if not source_version :
1346+ return - 1
1347+ if not default_version :
1348+ return 1
1349+ try :
1350+ source_version_parsed = parse (source_version )
1351+ default_version_parsed = parse (default_version )
1352+ return (source_version_parsed > default_version_parsed ) - (source_version_parsed < default_version_parsed )
1353+ except InvalidVersion :
1354+ return (source_version > default_version ) - (source_version < default_version )
1355+
12261356
12271357class _ActivityExecutor :
12281358 def __init__ (self , registry : _Registry , logger : logging .Logger ):
0 commit comments