11from abc import ABC , abstractmethod
2- from typing import Annotated , ClassVar , Final , TypeAlias
2+ from typing import Annotated , ClassVar , Final , TypeAlias , TypedDict
33
44from fastapi import FastAPI
55from pydantic import Field , NonNegativeInt , TypeAdapter , validate_call
66
7- from ._errors import OperationAlreadyRegisteredError , OperationNotFoundError
7+ from ._errors import (
8+ OperationAlreadyRegisteredError ,
9+ OperationNotFoundError ,
10+ StepNotFoundInoperationError ,
11+ )
812from ._models import OperationName , StepGroupName , StepName
913
1014
@@ -98,8 +102,8 @@ def get_step_subgroup_to_run(self) -> StepsSubGroup:
98102
99103
100104@validate_call (config = {"arbitrary_types_allowed" : True })
101- def _validate_operation (operation : Operation ) -> None :
102- detected_steps_names : set [StepName ] = set ()
105+ def _validate_operation (operation : Operation ) -> dict [ StepName , type [ BaseStep ]] :
106+ detected_steps_names : dict [StepName , type [ BaseStep ]] = {}
103107
104108 for k , step_group in enumerate (operation ):
105109 if isinstance (step_group , ParallelStepGroup ):
@@ -121,30 +125,57 @@ def _validate_operation(operation: Operation) -> None:
121125 msg = f"Step { step_name = } is already used in this operation { detected_steps_names = } "
122126 raise ValueError (msg )
123127
124- detected_steps_names .add (step_name )
128+ detected_steps_names [step_name ] = step
129+
130+ return detected_steps_names
131+
132+
133+ class _UpdateScheduleDataDict (TypedDict ):
134+ operation : Operation
135+ steps : dict [StepName , type [BaseStep ]]
125136
126137
127138class OperationRegistry :
128- _OPERATIONS : ClassVar [dict [str , Operation ]] = {}
139+ _OPERATIONS : ClassVar [dict [OperationName , _UpdateScheduleDataDict ]] = {}
129140
130141 @classmethod
131142 def register (cls , operation_name : OperationName , operation : Operation ) -> None :
132- _validate_operation (operation )
143+ steps = _validate_operation (operation )
133144
134145 if operation_name in cls ._OPERATIONS :
135146 raise OperationAlreadyRegisteredError (operation_name = operation_name )
136147
137- cls ._OPERATIONS [operation_name ] = operation
148+ cls ._OPERATIONS [operation_name ] = {"operation" : operation , "steps" : steps }
149+
150+ @classmethod
151+ def get_operation (cls , operation_name : OperationName ) -> Operation :
152+ if operation_name not in cls ._OPERATIONS :
153+ raise OperationNotFoundError (
154+ operation_name = operation_name ,
155+ registerd_operations = list (cls ._OPERATIONS .keys ()),
156+ )
157+
158+ return cls ._OPERATIONS [operation_name ]["operation" ]
138159
139160 @classmethod
140- def get (cls , operation_name : OperationName ) -> Operation :
161+ def get_step (
162+ cls , operation_name : OperationName , step_name : StepName
163+ ) -> type [BaseStep ]:
141164 if operation_name not in cls ._OPERATIONS :
142165 raise OperationNotFoundError (
143166 operation_name = operation_name ,
144167 registerd_operations = list (cls ._OPERATIONS .keys ()),
145168 )
146169
147- return cls ._OPERATIONS [operation_name ]
170+ steps_names = list (cls ._OPERATIONS [operation_name ]["steps" ].keys ())
171+ if step_name not in steps_names :
172+ raise StepNotFoundInoperationError (
173+ step_name = step_name ,
174+ operation_name = operation_name ,
175+ steps_names = steps_names ,
176+ )
177+
178+ return cls ._OPERATIONS [operation_name ]["steps" ][step_name ]
148179
149180 @classmethod
150181 def unregister (cls , operation_name : OperationName ) -> None :
0 commit comments