11import asyncio
22from dataclasses import dataclass , field
3+ from types import MappingProxyType
34from typing import Any , Callable , Dict , Optional
45import inspect
56
@@ -131,7 +132,16 @@ async def execute_task_instance(
131132
132133 # Prepare arguments using the TaskRun's context/inputs determined by the Orchestrator
133134 # The Orchestrator should have resolved all upstream tasks' data into task_run.context
134- pipeline_params = task_run .context .get ("params" , None ) if task_run .context else None
135+ if task_run .context :
136+ dict_params = task_run .context .get ("params" , None )
137+
138+ if pipeline .params :
139+ pipeline_params = pipeline .params .model_validate (dict_params )
140+ else :
141+ # TODO: This should raise at least a warning
142+ pipeline_params = dict_params
143+ else :
144+ pipeline_params = None
135145
136146 task_start_time = utcnow ()
137147 task_run_status = PipelineRunStatus .FAILED # Assume failure until success
@@ -228,6 +238,7 @@ async def run(
228238
@dataclass
class TaskFunctionSignature:
    """Parsed description of a task function's call signature.

    Holds the raw parameter mapping from ``inspect.signature`` together
    with the special arguments the executor recognizes: a pipeline-level
    ``params`` argument, an optional context argument, and the names of
    the remaining arguments that are fed from upstream task outputs.
    """

    # Raw mapping produced by inspect.signature(func).parameters.
    func_params: MappingProxyType[str, inspect.Parameter]
    # True when the function declares a `params` argument (pipeline params).
    has_params_arg: bool = False
    # Name of the context-receiving argument, if any was declared.
    context_arg: Optional[str] = None
    # Argument names resolved from upstream task outputs.
    input_arg_names: list[str] = field(default_factory=list)
@@ -245,9 +256,9 @@ def check_task_signature(func: Callable) -> TaskFunctionSignature:
245256 Where the params argument is the Pipeline input params.
246257 """
247258
248- result = TaskFunctionSignature ()
259+ result = TaskFunctionSignature (inspect . signature ( func ). parameters )
249260
250- for name , parameter in inspect . signature ( func ). parameters .items ():
261+ for name , parameter in result . func_params .items ():
251262 # Check for special arguments
252263 if name == "params" :
253264 result .has_params_arg = True
@@ -294,11 +305,18 @@ async def _execute_task(
294305
295306 # Load the TaskRuns for all upstream dependencies
296307 upstream_runs_metadata = get_task_runs_for_pipeline_run (
297- task_run .pipeline_run_id , task .upstream_task_ids
308+ task_run .pipeline_run_id , task_ids = task .upstream_task_ids
298309 )
299310
300311 # Build the map of task_id -> TaskRun model instance
301- metadata_map = {run .task_id : run for run in upstream_runs_metadata }
312+ metadata_map = {
313+ (
314+ f"{ run .task_id } .{ run .map_index } "
315+ if run .map_index is not None
316+ else run .task_id
317+ ): run
318+ for run in upstream_runs_metadata
319+ }
302320 runtime_context = Context (task_run , metadata_map )
303321
304322 # Iterate over arguments required by the function signature
@@ -307,6 +325,13 @@ async def _execute_task(
307325 # - If mapped, resolves to single item if arg_name == map_upstream_id.
308326 # - Otherwise, resolves to the full output of the upstream task named arg_name.
309327 input_data = runtime_context .get_output_data (task_id = arg_name )
328+
329+ arg_annotation = result .func_params [arg_name ].annotation
330+
331+ # If the argument is a Pydantic Model, we parse it
332+ if issubclass (arg_annotation , BaseModel ):
333+ input_data = arg_annotation .model_validate (input_data )
334+
310335 kwargs [arg_name ] = input_data
311336
312337 if pipeline_params and result .has_params_arg :
0 commit comments