Skip to content

Commit ccd4220

Browse files
committed
Systematically used localized expressions in the trace and add pdl__result in them
1 parent 33e1907 commit ccd4220

File tree

8 files changed

+169
-74
lines changed

8 files changed

+169
-74
lines changed

pdl-live-react/src/pdl_ast.d.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3823,6 +3823,7 @@ export interface ContributeValue {
38233823
}
38243824
export interface LocalizedExpression {
38253825
expr: Expr
3826+
pdl__result?: unknown
38263827
pdl__location?: PdlLocationType | null
38273828
}
38283829
export interface Expr {

src/pdl/pdl-schema.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6775,6 +6775,16 @@
67756775
"expr": {
67766776
"title": "Expr"
67776777
},
6778+
"pdl__result": {
6779+
"anyOf": [
6780+
{},
6781+
{
6782+
"type": "null"
6783+
}
6784+
],
6785+
"default": null,
6786+
"title": "Pdl Result"
6787+
},
67786788
"pdl__location": {
67796789
"anyOf": [
67806790
{

src/pdl/pdl_ast.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ class LocalizedExpression(BaseModel, Generic[LocalizedExpressionT]):
8484
arbitrary_types_allowed=True,
8585
model_title_generator=(lambda _: "LocalizedExpression"),
8686
)
87-
expr: LocalizedExpressionT
87+
expr: Any
88+
pdl__result: Optional[LocalizedExpressionT] = None
8889
pdl__location: Optional[PdlLocationType] = None
8990

9091

src/pdl/pdl_dumper.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
DataBlock,
1818
EmptyBlock,
1919
ErrorBlock,
20+
ExpressionType,
2021
FunctionBlock,
2122
GetBlock,
2223
GraniteioModelBlock,
@@ -28,6 +29,7 @@
2829
LastOfBlock,
2930
LitellmModelBlock,
3031
LitellmParameters,
32+
LocalizedExpression,
3133
MatchBlock,
3234
MessageBlock,
3335
ObjectBlock,
@@ -110,27 +112,26 @@ def block_to_dict( # noqa: C901
110112
match block:
111113
case LitellmModelBlock():
112114
d["platform"] = str(block.platform)
113-
d["model"] = block.model
114-
if block.input is not None:
115-
d["input"] = block_to_dict(block.input, json_compatible)
115+
d["model"] = expr_to_dict(block.model, json_compatible)
116+
d["input"] = block_to_dict(block.input, json_compatible)
116117
if block.parameters is not None:
117118
if isinstance(block.parameters, LitellmParameters):
118119
d["parameters"] = block.parameters.model_dump(
119120
exclude_unset=True, exclude_defaults=True
120121
)
121122
else:
122-
d["parameters"] = block.parameters
123+
d["parameters"] = expr_to_dict(block.parameters, json_compatible)
123124
if block.modelResponse is not None:
124125
d["modelResponse"] = block.modelResponse
125126
case GraniteioModelBlock():
126-
d["model"] = block.model
127+
d["model"] = expr_to_dict(block.model, json_compatible)
127128
d["platform"] = str(block.platform)
128-
d["backend"] = block.backend
129-
d["processor"] = block.processor
130-
if block.input is not None:
131-
d["input"] = block_to_dict(block.input, json_compatible)
129+
d["backend"] = expr_to_dict(block.backend, json_compatible)
130+
if block.processor is not None:
131+
d["processor"] = expr_to_dict(block.processor, json_compatible)
132+
d["input"] = block_to_dict(block.input, json_compatible)
132133
if block.parameters is not None:
133-
d["parameters"] = block.parameters
134+
d["parameters"] = expr_to_dict(block.parameters, json_compatible)
134135
if block.modelResponse is not None:
135136
d["modelResponse"] = block.modelResponse
136137
case CodeBlock():
@@ -139,7 +140,7 @@ def block_to_dict( # noqa: C901
139140
case GetBlock():
140141
d["get"] = block.get
141142
case DataBlock():
142-
d["data"] = data_to_dict(block.data, json_compatible)
143+
d["data"] = expr_to_dict(block.data, json_compatible)
143144
if block.raw:
144145
d["raw"] = block.raw
145146
case TextBlock():
@@ -163,7 +164,7 @@ def block_to_dict( # noqa: C901
163164
case MessageBlock():
164165
d["content"] = block_to_dict(block.content, json_compatible)
165166
case ReadBlock():
166-
d["read"] = block.read
167+
d["read"] = expr_to_dict(block.read, json_compatible)
167168
d["message"] = block.message
168169
d["multiline"] = block.multiline
169170
case IncludeBlock():
@@ -175,18 +176,18 @@ def block_to_dict( # noqa: C901
175176
if block.pdl__trace:
176177
d["pdl__trace"] = block_to_dict(block.pdl__trace, json_compatible)
177178
case IfBlock():
178-
d["if"] = block.condition
179+
d["if"] = expr_to_dict(block.condition, json_compatible)
179180
d["then"] = block_to_dict(block.then, json_compatible)
180181
if block.else_ is not None:
181182
d["else"] = block_to_dict(block.else_, json_compatible)
182183
if block.if_result is not None:
183184
d["if_result"] = block.if_result
184185
case MatchBlock():
185-
d["match"] = block.match_
186+
d["match"] = expr_to_dict(block.match_, json_compatible)
186187
d["with"] = [
187188
{
188189
"case": pattern_to_dict(match_case.case),
189-
"if": match_case.if_,
190+
"if": expr_to_dict(match_case.if_, json_compatible),
190191
"then": block_to_dict(match_case.then, json_compatible),
191192
"pdl__case_result": match_case.pdl__case_result,
192193
"pdl__if_result": match_case.pdl__if_result,
@@ -195,11 +196,17 @@ def block_to_dict( # noqa: C901
195196
for match_case in block.with_
196197
]
197198
case RepeatBlock():
198-
d["for"] = block.for_
199-
d["while"] = block.while_
199+
if block.for_ is not None:
200+
d["for"] = expr_to_dict(block.for_, json_compatible)
201+
if block.while_ is not None:
202+
d["while"] = expr_to_dict(block.while_, json_compatible)
200203
d["repeat"] = block_to_dict(block.repeat, json_compatible)
201-
d["until"] = block.until
202-
d["max_iterations"] = block.max_iterations
204+
if block.until is not None:
205+
d["until"] = expr_to_dict(block.until, json_compatible)
206+
if block.max_iterations is not None:
207+
d["max_iterations"] = expr_to_dict(
208+
block.max_iterations, json_compatible
209+
)
203210
d["join"] = join_to_dict(block.join)
204211
if block.pdl__trace is not None:
205212
d["pdl__trace"] = [
@@ -211,8 +218,8 @@ def block_to_dict( # noqa: C901
211218
# if block.scope is not None:
212219
# d["scope"] = scope_to_dict(block.scope, json_compatible)
213220
case CallBlock():
214-
d["call"] = block.call
215-
d["args"] = data_to_dict(block.args, json_compatible)
221+
d["call"] = expr_to_dict(block.call, json_compatible)
222+
d["args"] = expr_to_dict(block.args, json_compatible)
216223
if block.pdl__trace is not None:
217224
d["pdl__trace"] = block_to_dict(
218225
block.pdl__trace, json_compatible
@@ -249,14 +256,24 @@ def block_to_dict( # noqa: C901
249256
return d
250257

251258

252-
def data_to_dict(data: Any, json_compatible):
259+
def data_to_dict(data: Any, json_compatible: bool):
253260
if json_compatible:
254261
d = as_json(data)
255262
else:
256263
d = data
257264
return d
258265

259266

267+
def expr_to_dict(expr: ExpressionType, json_compatible: bool):
268+
if isinstance(expr, LocalizedExpression):
269+
d = {"expr": data_to_dict(expr.expr, json_compatible)}
270+
if expr.pdl__result is not None:
271+
d["pdl__result"] = data_to_dict(expr.pdl__result, json_compatible)
272+
else:
273+
d = data_to_dict(expr, json_compatible)
274+
return d
275+
276+
260277
def timing_to_dict(timing: PdlTiming) -> dict:
261278
d: dict = {}
262279
if timing.start_nanos != 0:

src/pdl/pdl_granite_io.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,27 @@
1313
)
1414
from .pdl_lazy import PdlConst, PdlLazy, lazy_apply
1515
from .pdl_llms import _LOOP
16+
from .pdl_utils import value_of_expr
1617

1718

1819
class GraniteioModel:
1920
@staticmethod
2021
def processor_of_block(block: GraniteioModelBlock):
22+
model = value_of_expr(block.model)
23+
backend = value_of_expr(block.backend)
24+
assert isinstance(model, str), f"The model should be a string: {model}"
2125
assert isinstance(
22-
block.model, str
23-
), f"The model should be a string: {block.model}"
24-
assert isinstance(
25-
block.backend, (dict, str)
26-
), f"The backend should be a string or a dictionnary: {block.backend}"
27-
match block.backend:
26+
backend, (dict, str)
27+
), f"The backend should be a string or a dictionnary: {backend}"
28+
match backend:
2829
case {"transformers": device}:
29-
assert isinstance(block.backend, dict)
30+
assert isinstance(backend, dict)
3031
from granite_io import make_backend
3132

3233
backend = make_backend(
3334
"transformers",
3435
{
35-
"model_name": block.model,
36+
"model_name": model,
3637
"device": device,
3738
},
3839
)
@@ -42,14 +43,15 @@ def processor_of_block(block: GraniteioModelBlock):
4243
backend = make_backend(
4344
backend_name,
4445
{
45-
"model_name": block.model,
46+
"model_name": model,
4647
},
4748
)
4849
case _:
49-
assert False, f"Unexpected backend: {block.backend}"
50-
processor_name = block.processor
51-
if processor_name is None:
52-
processor_name = block.model
50+
assert False, f"Unexpected backend: {backend}"
51+
if block.processor is None:
52+
processor_name = model
53+
else:
54+
processor_name = value_of_expr(block.processor)
5355
assert isinstance(
5456
processor_name, str
5557
), f"The processor should be a string: {processor_name}"
@@ -73,10 +75,14 @@ async def async_generate_text(
7375
block: GraniteioModelBlock,
7476
messages: ModelInput,
7577
) -> tuple[dict[str, Any], Any]:
78+
if block.parameters is None:
79+
parameters = None
80+
else:
81+
parameters = value_of_expr(block.parameters)
7682
try:
77-
assert block.parameters is None or isinstance(block.parameters, dict)
83+
assert parameters is None or isinstance(parameters, dict)
7884
io_processor = GraniteioModel.processor_of_block(block)
79-
inputs = GraniteioModel.build_message(messages, block.parameters)
85+
inputs = GraniteioModel.build_message(messages, parameters)
8086
result = io_processor.create_chat_completion(inputs) # pyright: ignore
8187
try: # TODO: update when new version of granite-io is released
8288
message = result.next_message.model_dump()
@@ -88,7 +94,9 @@ async def async_generate_text(
8894
raw_result,
8995
)
9096
except Exception as exc:
91-
message = f"Error during '{block.model}' model call: {repr(exc)}"
97+
message = (
98+
f"Error during '{value_of_expr(block.model)}' model call: {repr(exc)}"
99+
)
92100
loc = block.pdl__location
93101
raise PDLRuntimeError(
94102
message,

0 commit comments

Comments
 (0)