Skip to content

Commit 6f49609

Browse files
committed
Handling of localized expressions
1 parent 182f08a commit 6f49609

File tree

5 files changed

+32
-15
lines changed

5 files changed

+32
-15
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ dependencies = [
1515
"jsonschema~=4.0",
1616
"litellm~=1.49",
1717
"termcolor~=2.0",
18-
"ipython~=8.0"
18+
"ipython~=8.0",
19+
"strictyaml~=1.7.3"
1920
]
2021
authors = [
2122
{ name="Mandana Vaziri", email="[email protected]" },

src/pdl/pdl_ast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class LocalizedExpression(BaseModel):
6464
model_config = ConfigDict(extra="forbid", use_attribute_docstrings=True)
6565
expr: Any
6666
location: Optional[LocationType] = None
67-
_pdl_yaml_src: Optional[YamlSource] = None
67+
pdl_yaml_src: Optional[YamlSource] = None
6868

6969

7070
ExpressionType: TypeAlias = Any | LocalizedExpression
@@ -137,7 +137,7 @@ class Block(BaseModel):
137137
# Fields for internal use
138138
result: Optional[Any] = None
139139
location: Optional[LocationType] = None
140-
_pdl_yaml_src: Optional[YamlSource] = None
140+
pdl_yaml_src: Optional[YamlSource] = None
141141

142142

143143
class FunctionBlock(Block):
@@ -275,7 +275,7 @@ class ModelBlock(Block):
275275
class BamModelBlock(ModelBlock):
276276
platform: Literal[ModelPlatform.BAM]
277277
prompt_id: Optional[str] = None
278-
parameters: Optional[BamTextGenerationParameters | dict] = None
278+
parameters: Optional[BamTextGenerationParameters | ExpressionType] = None
279279
moderations: Optional[ModerationParameters] = None
280280
data: Optional[PromptTemplateData] = None
281281
constraints: Any = None # TODO
@@ -285,7 +285,7 @@ class LitellmModelBlock(ModelBlock):
285285
"""Call a LLM through the LiteLLM API: https://docs.litellm.ai/."""
286286

287287
platform: Literal[ModelPlatform.LITELLM] = ModelPlatform.LITELLM
288-
parameters: Optional[LitellmParameters | dict] = None
288+
parameters: Optional[LitellmParameters | ExpressionType] = None
289289

290290

291291
class CodeBlock(Block):

src/pdl/pdl_compilers/to_regex.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
IncludeBlock,
2020
LitellmModelBlock,
2121
LitellmParameters,
22+
LocalizedExpression,
2223
ModelBlock,
2324
ReadBlock,
2425
RepeatBlock,
@@ -273,10 +274,14 @@ def compile_block(
273274
"include_stop_sequence", False
274275
)
275276
else:
276-
stop_sequences = block.parameters.stop_sequences or []
277+
if isinstance(block.parameters, LocalizedExpression):
278+
parameters = block.parameters.expr
279+
else:
280+
parameters = block.parameters
281+
stop_sequences = parameters.stop_sequences or []
277282
include_stop_sequence = (
278-
block.parameters.include_stop_sequence is None
279-
or block.parameters.include_stop_sequence
283+
parameters.include_stop_sequence is None
284+
or parameters.include_stop_sequence
280285
)
281286
case LitellmModelBlock():
282287
if block.parameters is None:
@@ -285,6 +290,8 @@ def compile_block(
285290
else:
286291
if isinstance(block.parameters, LitellmParameters):
287292
parameters = block.parameters.model_dump()
293+
elif isinstance(block.parameters, LocalizedExpression):
294+
parameters = block.parameters.expr
288295
else:
289296
parameters = block.parameters
290297
stop_sequences = parameters.get("stop", [])

src/pdl/pdl_interpreter.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
LastOfBlock,
4747
LitellmModelBlock,
4848
LitellmParameters,
49+
LocalizedExpression,
4950
LocationType,
5051
Message,
5152
MessageBlock,
@@ -917,8 +918,12 @@ def process_condition_of(
917918
EXPR_END_STRING = "}"
918919

919920

920-
def process_expr(scope: ScopeType, expr: Any, loc: LocationType) -> Any:
921+
def process_expr(
922+
scope: ScopeType, expr: Any, loc: LocationType
923+
) -> Any: # pylint: disable=too-many-return-statements
921924
result: Any
925+
if isinstance(expr, LocalizedExpression):
926+
return process_expr(scope, expr.expr, loc)
922927
if isinstance(expr, str):
923928
try:
924929
if expr.startswith(EXPR_START_STRING) and expr.endswith(EXPR_END_STRING):
@@ -1001,7 +1006,7 @@ def step_call_model(
10011006
],
10021007
]:
10031008
# evaluate model name
1004-
_, concrete_block = process_expr_of(block, "model", scope, loc)
1009+
model, concrete_block = process_expr_of(block, "model", scope, loc)
10051010
# evaluate model params
10061011
match concrete_block:
10071012
case BamModelBlock():
@@ -1060,9 +1065,7 @@ def get_transformed_inputs(kwargs):
10601065
if "input" in litellm_params:
10611066
append_log(state, "Model Input", litellm_params["input"])
10621067
else:
1063-
append_log(
1064-
state, "Model Input", messages_to_str(concrete_block.model, model_input)
1065-
)
1068+
append_log(state, "Model Input", messages_to_str(model, model_input))
10661069
background: Messages = [msg]
10671070
result = msg["content"]
10681071
append_log(state, "Model Output", result)
@@ -1104,6 +1107,8 @@ def generate_client_response_streaming(
11041107
model_input: Messages,
11051108
) -> Generator[YieldMessage, Any, Message]:
11061109
msg_stream: Generator[Message, Any, None]
1110+
assert isinstance(block.model, str) # block is a "concrete block"
1111+
assert isinstance(block.parameters, dict) # block is a "concrete block"
11071112
model_input_str = messages_to_str(block.model, model_input)
11081113
match block:
11091114
case BamModelBlock():
@@ -1158,6 +1163,8 @@ def generate_client_response_single(
11581163
block: BamModelBlock | LitellmModelBlock,
11591164
model_input: Messages,
11601165
) -> Generator[YieldMessage, Any, Message]:
1166+
assert isinstance(block.model, str) # block is a "concrete block"
1167+
assert isinstance(block.parameters, dict) # block is a "concrete block"
11611168
msg: Message
11621169
model_input_str = messages_to_str(block.model, model_input)
11631170
match block:
@@ -1189,6 +1196,8 @@ def generate_client_response_batching( # pylint: disable=too-many-arguments
11891196
# model: str,
11901197
model_input: Messages,
11911198
) -> Generator[YieldMessage, Any, Message]:
1199+
assert isinstance(block.model, str) # block is a "concrete block"
1200+
assert isinstance(block.parameters, dict) # block is a "concrete block"
11921201
model_input_str = messages_to_str(block.model, model_input)
11931202
match block:
11941203
case BamModelBlock():

src/pdl/pdl_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def set_location(
4949
pdl: Any,
5050
loc: YamlSource,
5151
):
52-
if hasattr(pdl, "_pdl_yaml_src"):
53-
pdl._pdl_yaml_src = loc
52+
if hasattr(pdl, "pdl_yaml_src"):
53+
pdl.pdl_yaml_src = loc
5454
if isinstance(loc.data, dict):
5555
for x, v in loc.items():
5656
if hasattr(pdl, x.data):

0 commit comments

Comments
 (0)