6363 get_or_create_func_info ,
6464 get_temp_workflow_type ,
6565 set_dbos_func_name ,
66+ set_func_info ,
6667 set_temp_workflow_type ,
6768)
6869from ._roles import check_required_roles
@@ -286,6 +287,7 @@ def execute_workflow_by_id(
286287 ctx .request = (
287288 _serialization .deserialize (request ) if request is not None else None
288289 )
290+ # If this function belongs to a configured class, add that class instance as its first argument
289291 if status ["config_name" ] is not None :
290292 config_name = status ["config_name" ]
291293 class_name = status ["class_name" ]
@@ -295,59 +297,30 @@ def execute_workflow_by_id(
295297 workflow_id ,
296298 f"Cannot execute workflow because instance '{ iname } ' is not registered" ,
297299 )
298-
299- if startNew :
300- return start_workflow (
301- dbos ,
302- wf_func ,
303- status ["queue_name" ],
304- True ,
305- dbos ._registry .instance_info_map [iname ],
306- * inputs ["args" ],
307- ** inputs ["kwargs" ],
308- )
309- else :
310- with SetWorkflowID (workflow_id ):
311- return start_workflow (
312- dbos ,
313- wf_func ,
314- status ["queue_name" ],
315- True ,
316- dbos ._registry .instance_info_map [iname ],
317- * inputs ["args" ],
318- ** inputs ["kwargs" ],
319- )
300+ class_instance = dbos ._registry .instance_info_map [iname ]
301+ inputs ["args" ] = (class_instance ,) + inputs ["args" ]
302+ # If this function is a class method, add that class object as its first argument
320303 elif status ["class_name" ] is not None :
321304 class_name = status ["class_name" ]
322305 if class_name not in dbos ._registry .class_info_map :
323306 raise DBOSWorkflowFunctionNotFoundError (
324307 workflow_id ,
325308 f"Cannot execute workflow because class '{ class_name } ' is not registered" ,
326309 )
310+ class_object = dbos ._registry .class_info_map [class_name ]
311+ inputs ["args" ] = (class_object ,) + inputs ["args" ]
327312
328- if startNew :
329- return start_workflow (
330- dbos ,
331- wf_func ,
332- status ["queue_name" ],
333- True ,
334- dbos ._registry .class_info_map [class_name ],
335- * inputs ["args" ],
336- ** inputs ["kwargs" ],
337- )
338- else :
339- with SetWorkflowID (workflow_id ):
340- return start_workflow (
341- dbos ,
342- wf_func ,
343- status ["queue_name" ],
344- True ,
345- dbos ._registry .class_info_map [class_name ],
346- * inputs ["args" ],
347- ** inputs ["kwargs" ],
348- )
313+ if startNew :
314+ return start_workflow (
315+ dbos ,
316+ wf_func ,
317+ status ["queue_name" ],
318+ True ,
319+ * inputs ["args" ],
320+ ** inputs ["kwargs" ],
321+ )
349322 else :
350- if startNew :
323+ with SetWorkflowID ( workflow_id ) :
351324 return start_workflow (
352325 dbos ,
353326 wf_func ,
@@ -356,16 +329,6 @@ def execute_workflow_by_id(
356329 * inputs ["args" ],
357330 ** inputs ["kwargs" ],
358331 )
359- else :
360- with SetWorkflowID (workflow_id ):
361- return start_workflow (
362- dbos ,
363- wf_func ,
364- status ["queue_name" ],
365- True ,
366- * inputs ["args" ],
367- ** inputs ["kwargs" ],
368- )
369332
370333
371334@overload
@@ -398,9 +361,12 @@ def start_workflow(
398361 * args : P .args ,
399362 ** kwargs : P .kwargs ,
400363) -> "WorkflowHandle[R]" :
364+ # If the function has a class, add the class object as its first argument
401365 fself : Optional [object ] = None
402366 if hasattr (func , "__self__" ):
403367 fself = func .__self__
368+ if fself is not None :
369+ args = (fself ,) + args # type: ignore
404370
405371 fi = get_func_info (func )
406372 if fi is None :
@@ -436,17 +402,13 @@ def start_workflow(
436402 new_wf_ctx .id_assigned_for_next_workflow = new_wf_ctx .assign_workflow_id ()
437403 new_wf_id = new_wf_ctx .id_assigned_for_next_workflow
438404
439- gin_args : Tuple [Any , ...] = args
440- if fself is not None :
441- gin_args = (fself ,)
442-
443405 status = _init_workflow (
444406 dbos ,
445407 new_wf_ctx ,
446408 inputs = inputs ,
447409 wf_name = get_dbos_func_name (func ),
448- class_name = get_dbos_class_name (fi , func , gin_args ),
449- config_name = get_config_name (fi , func , gin_args ),
410+ class_name = get_dbos_class_name (fi , func , args ),
411+ config_name = get_config_name (fi , func , args ),
450412 temp_wf_type = get_temp_workflow_type (func ),
451413 queue = queue_name ,
452414 max_recovery_attempts = fi .max_recovery_attempts ,
@@ -464,27 +426,15 @@ def start_workflow(
464426 )
465427 return WorkflowHandlePolling (new_wf_id , dbos )
466428
467- if fself is not None :
468- future = dbos ._executor .submit (
469- cast (Callable [..., R ], _execute_workflow_wthread ),
470- dbos ,
471- status ,
472- func ,
473- new_wf_ctx ,
474- fself ,
475- * args ,
476- ** kwargs ,
477- )
478- else :
479- future = dbos ._executor .submit (
480- cast (Callable [..., R ], _execute_workflow_wthread ),
481- dbos ,
482- status ,
483- func ,
484- new_wf_ctx ,
485- * args ,
486- ** kwargs ,
487- )
429+ future = dbos ._executor .submit (
430+ cast (Callable [..., R ], _execute_workflow_wthread ),
431+ dbos ,
432+ status ,
433+ func ,
434+ new_wf_ctx ,
435+ * args ,
436+ ** kwargs ,
437+ )
488438 return WorkflowHandleFuture (new_wf_id , future , dbos )
489439
490440
@@ -516,6 +466,8 @@ def workflow_wrapper(
516466
517467 @wraps (func )
518468 def wrapper (* args : Any , ** kwargs : Any ) -> R :
469+ fi = get_func_info (func )
470+ assert fi is not None
519471 if dbosreg .dbos is None :
520472 raise DBOSException (
521473 f"Function { func .__name__ } invoked before DBOS initialized"
@@ -726,6 +678,8 @@ def temp_wf(*args: Any, **kwargs: Any) -> Any:
726678 set_temp_workflow_type (temp_wf , "transaction" )
727679 dbosreg .register_wf_function (get_dbos_func_name (temp_wf ), wrapped_wf )
728680 wrapper .__orig_func = temp_wf # type: ignore
681+ set_func_info (wrapped_wf , get_or_create_func_info (func ))
682+ set_func_info (temp_wf , get_or_create_func_info (func ))
729683
730684 return cast (F , wrapper )
731685
@@ -875,6 +829,8 @@ async def temp_wf_async(*args: Any, **kwargs: Any) -> Any:
875829 set_temp_workflow_type (temp_wf , "step" )
876830 dbosreg .register_wf_function (get_dbos_func_name (temp_wf ), wrapped_wf )
877831 wrapper .__orig_func = temp_wf # type: ignore
832+ set_func_info (wrapped_wf , get_or_create_func_info (func ))
833+ set_func_info (temp_wf , get_or_create_func_info (func ))
878834
879835 return cast (Callable [P , R ], wrapper )
880836
0 commit comments