Skip to content

Commit a8c9f97

Browse files
committed
feat: replay an execution
Signed-off-by: Louis Mandel <[email protected]>
1 parent 0398cf4 commit a8c9f97

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

src/pdl/pdl.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class InterpreterConfig(TypedDict, total=False):
4747
"""
4848
cwd: Path
4949
"""Path considered as the current working directory for file reading."""
50+
replay: dict[str, Any]
51+
"""Execute the program reusing some already computed values.
52+
"""
5053

5154

5255
def exec_program(
@@ -66,9 +69,10 @@ def exec_program(
6669
output: Configure the output of the returned value of this function. Defaults to `"result"`
6770
6871
Returns:
69-
Return the final result if `output` is set to `"result"`. If set of `all`, it returns a dictionary containing, `result`, `scope`, and `trace`.
72+
Return the final result if `output` is set to `"result"`. If set of `all`, it returns a dictionary containing, `result`, `scope`, `trace`, and `replay`.
7073
"""
71-
config = config or {}
74+
config = config or InterpreterConfig()
75+
config["replay"] = dict(config.get("replay", {}))
7276
state = InterpreterState(**config)
7377
if not isinstance(scope, PdlDict):
7478
scope = PdlDict(scope or {})
@@ -83,7 +87,12 @@ def exec_program(
8387
return result
8488
case "all":
8589
scope = future_scope.result()
86-
return {"result": result, "scope": scope, "trace": trace}
90+
return {
91+
"result": result,
92+
"scope": scope,
93+
"trace": trace,
94+
"replay": state.replay,
95+
}
8796
case _:
8897
assert False, 'The `output` variable should be "result" or "all"'
8998

src/pdl/pdl_interpreter.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,7 @@
152152
write_trace,
153153
)
154154

155-
empty_scope: ScopeType = PdlDict(
156-
{"pdl_context": DependentContext([]), "__pdl_replay": {}}
157-
)
155+
empty_scope: ScopeType = PdlDict({"pdl_context": DependentContext([])})
158156

159157

160158
RefT = TypeVar("RefT")
@@ -190,6 +188,7 @@ class InterpreterState(BaseModel):
190188
"""Event loop to schedule LLM calls."""
191189
current_pdl_context: Ref[LazyMessages] = Ref(DependentContext([]))
192190
"""Current value of the context set at the beginning of the execution of the block."""
191+
replay: dict[str, Any] = {}
193192

194193
def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
195194
return self.model_copy(update={"yield_result": b})
@@ -498,7 +497,7 @@ def process_advance_block_retry(
498497
trial_total = max_retry + 1
499498
for trial_idx in range(trial_total):
500499
try:
501-
result, background, new_scope, trace = process_block_body(
500+
result, background, new_scope, trace = process_block_body_with_replay(
502501
state, scope, block, loc
503502
)
504503
if block.requirements != []:
@@ -633,11 +632,9 @@ def process_block_body_with_replay(
633632
) -> tuple[PdlLazy[Any], LazyMessages, ScopeType, AdvancedBlockType]:
634633
if isinstance(block, LeafBlock):
635634
block_id = block.pdl__id
636-
replay_scope = scope["__pdl_replay"]
637635
assert isinstance(block_id, str)
638-
assert isinstance(replay_scope, dict)
639636
try:
640-
result = replay_scope[block_id]
637+
result = state.replay[block_id]
641638
background: LazyMessages = SingletonContext(
642639
PdlDict({"role": state.role, "content": result})
643640
)
@@ -650,7 +647,7 @@ def process_block_body_with_replay(
650647
result, background, scope, trace = process_block_body(
651648
state, scope, block, loc
652649
)
653-
scope = scope | {"__pdl_replay": (replay_scope | {block_id: result})}
650+
state.replay[block_id] = result
654651
else:
655652
result, background, scope, trace = process_block_body(state, scope, block, loc)
656653
return result, background, scope, trace

0 commit comments

Comments
 (0)