10 changes: 10 additions & 0 deletions AGENTS.md
@@ -0,0 +1,10 @@
# Overview
* use `just` as the command runner -- `just val` in the backend runs typechecking and testing. Always run it after any change to Python code
* the project is managed by `uv` -- use it for Python-related subcommands, e.g. `uv run pytest` for tests or `uv run ty` for typechecking
* there are two Python modules -- `cascade`, a low-level execution engine, and `earthkit.workflows`, a higher-level abstraction on top of it. Each has its own subdirectory in `tests`
* always use type annotations; they are enforced
* when working with a package with poor typing coverage, such as sqlalchemy, use a `ty: ignore` comment
* likewise use `ty: ignore` when ty itself is not powerful enough
* use `typing.cast` when the code logic implicitly erases type information (see the first sketch below this list)
* prefer `pydantic.BaseModel` or `dataclasses.dataclass` objects for capturing contracts and interfaces (see the second sketch below this list)
* ideally keep them plain, stateless, frozen, and free of methods -- we often serialize these objects over to other Python processes or different languages
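
To illustrate the escape-hatch bullets, a minimal sketch -- not part of this diff; the sqlalchemy query and the exact `ty: ignore` placement are illustrative assumptions:

```python
import typing

from sqlalchemy import create_engine, text  # a package with patchy typing coverage

def user_names() -> list[str]:
    engine = create_engine("sqlite:///:memory:")
    with engine.connect() as conn:
        result = conn.execute(text("select name from users"))  # ty: ignore
        # the driver layer erases the row type; we know each row is a
        # 1-tuple of str here, so typing.cast restores that information
        rows = typing.cast(list[tuple[str]], result.fetchall())
    return [name for (name,) in rows]
```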
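And a sketch of the contract style from the last two bullets; `TaskSpec` and its fields are hypothetical, not names from this repo:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TaskSpec:
    """Plain, stateless, frozen, method-free -- cheap to pickle and easy to
    mirror as JSON for other Python processes or other languages."""

    task_id: str
    inputs: tuple[str, ...]  # tuples, not lists: immutable and hashable
    outputs: tuple[str, ...]
```

Freezing makes accidental mutation fail loudly before the object crosses a process boundary, and keeping it method-free means a consumer in another language only has to mirror the fields.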
26 changes: 15 additions & 11 deletions src/cascade/executor/runner/runner.py
@@ -12,6 +12,7 @@
"""

import logging
from collections.abc import Generator
from dataclasses import dataclass
from time import perf_counter_ns
from typing import Any, Callable
@@ -81,19 +82,9 @@ def run(taskId: TaskId, executionContext: ExecutionContext, memory: Memory) -> N

     # invoke
     result = func(*args, **kwargs)
-    if outputsN == 1:
-        mark({"task": taskId, "action": TaskLifecycle.computed})
-        run_end = perf_counter_ns()
 
     # store outputs
-    if outputsN == 1:
-        outputKey, outputSchema = outputs[0]
-        outputId = DatasetId(taskId, outputKey)
-        memory.handle(
-            outputId, outputSchema, result, outputId in executionContext.publish
-        )
-        mark({"task": taskId, "action": TaskLifecycle.published})
-    else:
+    if isinstance(result, Generator):
         outputsI = iter(outputs)
         for (outputKey, outputSchema), outputValue in zip(outputsI, result):
             outputId = DatasetId(taskId, outputKey)
@@ -113,6 +104,19 @@
         mark({"task": taskId, "action": TaskLifecycle.computed})
         run_end = perf_counter_ns()
         mark({"task": taskId, "action": TaskLifecycle.published})
+    else:
+        if outputsN != 1:
+            raise ValueError(
+                f"task {taskId} returned non-generator result but has {outputsN} outputs declared"
+            )
+        mark({"task": taskId, "action": TaskLifecycle.computed})
+        run_end = perf_counter_ns()
+        outputKey, outputSchema = outputs[0]
+        outputId = DatasetId(taskId, outputKey)
+        memory.handle(
+            outputId, outputSchema, result, outputId in executionContext.publish
+        )
+        mark({"task": taskId, "action": TaskLifecycle.published})
     end = perf_counter_ns()
 
     trace(Microtrace.wrk_task, end - start)
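
For context, a self-contained sketch of the dispatch pattern this diff introduces: branch on whether the task function returned a generator and fan yielded values out to the declared outputs, otherwise require exactly one declared output. Names like `store_outputs` and `handle` are hypothetical stand-ins for the runner's `memory.handle` machinery:

```python
from collections.abc import Generator
from typing import Any, Callable

def store_outputs(
    func: Callable[[], Any],
    outputs: list[str],
    handle: Callable[[str, Any], None],
) -> None:
    result = func()
    if isinstance(result, Generator):
        # multi-output task: pair declared outputs with yielded values, lazily
        for key, value in zip(outputs, result):
            handle(key, value)
    else:
        # plain return value: exactly one output must be declared
        if len(outputs) != 1:
            raise ValueError(f"non-generator result but {len(outputs)} outputs declared")
        handle(outputs[0], result)

# usage: a two-output generator task
def task() -> Generator[int, None, None]:
    yield 1
    yield 2

store_outputs(task, ["a", "b"], lambda key, value: print(key, value))  # a 1, then b 2
```

Dispatching on the runtime type rather than on `outputsN` lets a single-output task simply return its value, while multi-output tasks stream their results one at a time instead of materializing them all up front.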