Skip to content

Commit 30787f7

Browse files
fix context resolution for mapped tasks
1 parent f63fb7d commit 30787f7

File tree

5 files changed

+63
-10
lines changed

5 files changed

+63
-10
lines changed

frontend/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "plombery",
3-
"version": "0.5.2-beta2",
3+
"version": "0.6.0-beta1",
44
"description": "",
55
"license": "MIT",
66
"author": {

src/plombery/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# Keep it aligned with version in package.json
22

3-
__version__ = "0.5.2-beta2"
3+
__version__ = "0.6.0-beta1"

src/plombery/database/repository.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,17 @@ def create_task_run_output(
187187
) -> models.TaskRunOutput:
188188
"""Creates a new TaskRunOutput record and returns the instance."""
189189
with SessionLocal() as session:
190+
data = (
191+
task_output.data.__dict__
192+
if hasattr(task_output.data, "__dict__")
193+
else task_output.data
194+
)
195+
190196
db_output = models.TaskRunOutput(
191-
**task_output.model_dump(), size=len(task_output.data)
197+
mimetype=task_output.mimetype,
198+
encoding=task_output.encoding,
199+
data=data,
200+
size=0,
192201
)
193202
session.add(db_output)
194203
session.flush()

src/plombery/orchestrator/context.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from plombery.database.repository import get_task_run_output_by_id
33
from plombery.database.models import TaskRun
44
from plombery.logger import get_logger
5+
from plombery.pipeline.context import task_context
6+
from plombery.pipeline.tasks import MappingMode
57

68

79
class Context:
@@ -13,13 +15,25 @@ def __init__(self, _task_run: TaskRun, upstream_task_runs: dict[str, TaskRun]):
1315
self._task_run = _task_run
1416
self._upstream_task_runs = upstream_task_runs
1517
self.logger = get_logger()
18+
self.task = task_context.get()
1619

1720
def get_output_data(self, task_id: str) -> Optional[Any]:
1821
"""
1922
Imperatively retrieves the full TaskRunOutput data (XCom) for a specified
2023
upstream task, performing a database lookup only upon call.
2124
"""
22-
target_task_run = self._upstream_task_runs.get(task_id)
25+
if (
26+
self._task_run.map_index is not None
27+
# if it's a Chained Fan Out the upstream is returning a
28+
# primitive value and not an array, so we retrieve the value of
29+
# the corresponding task + map_index
30+
and self.task.mapping_mode == MappingMode.CHAINED_FAN_OUT
31+
):
32+
task_full_id = f"{task_id}.{self._task_run.map_index}"
33+
else:
34+
task_full_id = task_id
35+
36+
target_task_run = self._upstream_task_runs.get(task_full_id)
2337

2438
if not target_task_run or not target_task_run.task_output_id:
2539
# Task ID not found in upstream dependencies
@@ -29,7 +43,12 @@ def get_output_data(self, task_id: str) -> Optional[Any]:
2943
output_record = get_task_run_output_by_id(target_task_run.task_output_id)
3044

3145
if output_record:
32-
if self._task_run.map_index is not None:
46+
if (
47+
self._task_run.map_index is not None
48+
# If it's a Fan Out task then the upstream is returning a list and we need
49+
# to get the item at the specific index
50+
and self.task.mapping_mode == MappingMode.FAN_OUT
51+
):
3352
return output_record.data[self._task_run.map_index]
3453

3554
# Return the data stored in the 'data' JSON column

src/plombery/orchestrator/executor.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from dataclasses import dataclass, field
3+
from types import MappingProxyType
34
from typing import Any, Callable, Dict, Optional
45
import inspect
56

@@ -131,7 +132,16 @@ async def execute_task_instance(
131132

132133
# Prepare arguments using the TaskRun's context/inputs determined by the Orchestrator
133134
# The Orchestrator should have resolved all upstream tasks' data into task_run.context
134-
pipeline_params = task_run.context.get("params", None) if task_run.context else None
135+
if task_run.context:
136+
dict_params = task_run.context.get("params", None)
137+
138+
if pipeline.params:
139+
pipeline_params = pipeline.params.model_validate(dict_params)
140+
else:
141+
# TODO: This should raise at least a warning
142+
pipeline_params = dict_params
143+
else:
144+
pipeline_params = None
135145

136146
task_start_time = utcnow()
137147
task_run_status = PipelineRunStatus.FAILED # Assume failure until success
@@ -228,6 +238,7 @@ async def run(
228238

229239
@dataclass
230240
class TaskFunctionSignature:
241+
func_params: MappingProxyType[str, inspect.Parameter]
231242
has_params_arg: bool = False
232243
context_arg: Optional[str] = None
233244
input_arg_names: list[str] = field(default_factory=list)
@@ -245,9 +256,9 @@ def check_task_signature(func: Callable) -> TaskFunctionSignature:
245256
Where the params argument is the Pipeline input params.
246257
"""
247258

248-
result = TaskFunctionSignature()
259+
result = TaskFunctionSignature(inspect.signature(func).parameters)
249260

250-
for name, parameter in inspect.signature(func).parameters.items():
261+
for name, parameter in result.func_params.items():
251262
# Check for special arguments
252263
if name == "params":
253264
result.has_params_arg = True
@@ -294,11 +305,18 @@ async def _execute_task(
294305

295306
# Load the TaskRuns for all upstream dependencies
296307
upstream_runs_metadata = get_task_runs_for_pipeline_run(
297-
task_run.pipeline_run_id, task.upstream_task_ids
308+
task_run.pipeline_run_id, task_ids=task.upstream_task_ids
298309
)
299310

300311
# Build the map of task_id -> TaskRun model instance
301-
metadata_map = {run.task_id: run for run in upstream_runs_metadata}
312+
metadata_map = {
313+
(
314+
f"{run.task_id}.{run.map_index}"
315+
if run.map_index is not None
316+
else run.task_id
317+
): run
318+
for run in upstream_runs_metadata
319+
}
302320
runtime_context = Context(task_run, metadata_map)
303321

304322
# Iterate over arguments required by the function signature
@@ -307,6 +325,13 @@ async def _execute_task(
307325
# - If mapped, resolves to single item if arg_name == map_upstream_id.
308326
# - Otherwise, resolves to the full output of the upstream task named arg_name.
309327
input_data = runtime_context.get_output_data(task_id=arg_name)
328+
329+
arg_annotation = result.func_params[arg_name].annotation
330+
331+
# If the argument is a Pydantic Model, we parse it
332+
if issubclass(arg_annotation, BaseModel):
333+
input_data = arg_annotation.model_validate(input_data)
334+
310335
kwargs[arg_name] = input_data
311336

312337
if pipeline_params and result.has_params_arg:

0 commit comments

Comments
 (0)