Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pdl-live-react/src/pdl_ast.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4055,6 +4055,7 @@ export interface ContributeValue {
}
export interface LocalizedExpression {
expr: Expr
pdl__result?: unknown
pdl__location?: PdlLocationType | null
}
export interface Expr {
Expand Down
10 changes: 10 additions & 0 deletions src/pdl/pdl-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -7294,6 +7294,16 @@
"expr": {
"title": "Expr"
},
"pdl__result": {
"anyOf": [
{},
{
"type": "null"
}
],
"default": null,
"title": "Pdl Result"
},
"pdl__location": {
"anyOf": [
{
Expand Down
3 changes: 2 additions & 1 deletion src/pdl/pdl_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class LocalizedExpression(BaseModel, Generic[LocalizedExpressionT]):
arbitrary_types_allowed=True,
model_title_generator=(lambda _: "LocalizedExpression"),
)
expr: LocalizedExpressionT
expr: Any
pdl__result: Optional[LocalizedExpressionT] = None
pdl__location: Optional[PdlLocationType] = None


Expand Down
61 changes: 39 additions & 22 deletions src/pdl/pdl_dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DataBlock,
EmptyBlock,
ErrorBlock,
ExpressionType,
FunctionBlock,
GetBlock,
GraniteioModelBlock,
Expand All @@ -29,6 +30,7 @@
LastOfBlock,
LitellmModelBlock,
LitellmParameters,
LocalizedExpression,
MatchBlock,
MessageBlock,
ObjectBlock,
Expand Down Expand Up @@ -112,29 +114,28 @@ def block_to_dict( # noqa: C901
match block:
case LitellmModelBlock():
d["platform"] = str(block.platform)
d["model"] = block.model
if block.input is not None:
d["input"] = block_to_dict(block.input, json_compatible)
d["model"] = expr_to_dict(block.model, json_compatible)
d["input"] = block_to_dict(block.input, json_compatible)
if block.parameters is not None:
if isinstance(block.parameters, LitellmParameters):
d["parameters"] = block.parameters.model_dump(
exclude_unset=True, exclude_defaults=True
)
else:
d["parameters"] = block.parameters
d["parameters"] = expr_to_dict(block.parameters, json_compatible)
if block.modelResponse is not None:
d["modelResponse"] = block.modelResponse
if block.pdl__usage is not None:
d["pdl__usage"] = usage_to_dict(block.pdl__usage)
case GraniteioModelBlock():
d["model"] = block.model
d["model"] = expr_to_dict(block.model, json_compatible)
d["platform"] = str(block.platform)
d["backend"] = block.backend
d["processor"] = block.processor
if block.input is not None:
d["input"] = block_to_dict(block.input, json_compatible)
d["backend"] = expr_to_dict(block.backend, json_compatible)
if block.processor is not None:
d["processor"] = expr_to_dict(block.processor, json_compatible)
d["input"] = block_to_dict(block.input, json_compatible)
if block.parameters is not None:
d["parameters"] = block.parameters
d["parameters"] = expr_to_dict(block.parameters, json_compatible)
if block.modelResponse is not None:
d["modelResponse"] = block.modelResponse
if block.pdl__usage is not None:
Expand All @@ -147,7 +148,7 @@ def block_to_dict( # noqa: C901
case GetBlock():
d["get"] = block.get
case DataBlock():
d["data"] = data_to_dict(block.data, json_compatible)
d["data"] = expr_to_dict(block.data, json_compatible)
if block.raw:
d["raw"] = block.raw
case TextBlock():
Expand All @@ -171,7 +172,7 @@ def block_to_dict( # noqa: C901
case MessageBlock():
d["content"] = block_to_dict(block.content, json_compatible)
case ReadBlock():
d["read"] = block.read
d["read"] = expr_to_dict(block.read, json_compatible)
d["message"] = block.message
d["multiline"] = block.multiline
case IncludeBlock():
Expand All @@ -183,18 +184,18 @@ def block_to_dict( # noqa: C901
if block.pdl__trace:
d["pdl__trace"] = block_to_dict(block.pdl__trace, json_compatible)
case IfBlock():
d["if"] = block.condition
d["if"] = expr_to_dict(block.condition, json_compatible)
d["then"] = block_to_dict(block.then, json_compatible)
if block.else_ is not None:
d["else"] = block_to_dict(block.else_, json_compatible)
if block.if_result is not None:
d["if_result"] = block.if_result
case MatchBlock():
d["match"] = block.match_
d["match"] = expr_to_dict(block.match_, json_compatible)
d["with"] = [
{
"case": pattern_to_dict(match_case.case),
"if": match_case.if_,
"if": expr_to_dict(match_case.if_, json_compatible),
"then": block_to_dict(match_case.then, json_compatible),
"pdl__case_result": match_case.pdl__case_result,
"pdl__if_result": match_case.pdl__if_result,
Expand All @@ -203,11 +204,17 @@ def block_to_dict( # noqa: C901
for match_case in block.with_
]
case RepeatBlock():
d["for"] = block.for_
d["while"] = block.while_
if block.for_ is not None:
d["for"] = expr_to_dict(block.for_, json_compatible)
if block.while_ is not None:
d["while"] = expr_to_dict(block.while_, json_compatible)
d["repeat"] = block_to_dict(block.repeat, json_compatible)
d["until"] = block.until
d["max_iterations"] = block.max_iterations
if block.until is not None:
d["until"] = expr_to_dict(block.until, json_compatible)
if block.max_iterations is not None:
d["max_iterations"] = expr_to_dict(
block.max_iterations, json_compatible
)
d["join"] = join_to_dict(block.join)
if block.pdl__trace is not None:
d["pdl__trace"] = [
Expand All @@ -219,8 +226,8 @@ def block_to_dict( # noqa: C901
# if block.scope is not None:
# d["scope"] = scope_to_dict(block.scope, json_compatible)
case CallBlock():
d["call"] = block.call
d["args"] = data_to_dict(block.args, json_compatible)
d["call"] = expr_to_dict(block.call, json_compatible)
d["args"] = expr_to_dict(block.args, json_compatible)
if block.pdl__trace is not None:
d["pdl__trace"] = block_to_dict(
block.pdl__trace, json_compatible
Expand Down Expand Up @@ -257,14 +264,24 @@ def block_to_dict( # noqa: C901
return d


def data_to_dict(data: Any, json_compatible):
def data_to_dict(data: Any, json_compatible: bool):
if json_compatible:
d = as_json(data)
else:
d = data
return d


def expr_to_dict(expr: ExpressionType, json_compatible: bool):
if isinstance(expr, LocalizedExpression):
d = {"expr": data_to_dict(expr.expr, json_compatible)}
if expr.pdl__result is not None:
d["pdl__result"] = data_to_dict(expr.pdl__result, json_compatible)
else:
d = data_to_dict(expr, json_compatible)
return d


def timing_to_dict(timing: PdlTiming) -> dict:
d: dict = {}
if timing.start_nanos != 0:
Expand Down
40 changes: 24 additions & 16 deletions src/pdl/pdl_granite_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,27 @@
)
from .pdl_lazy import PdlConst, PdlLazy, lazy_apply
from .pdl_llms import _LOOP
from .pdl_utils import value_of_expr


class GraniteioModel:
@staticmethod
def processor_of_block(block: GraniteioModelBlock):
model = value_of_expr(block.model)
backend = value_of_expr(block.backend)
assert isinstance(model, str), f"The model should be a string: {model}"
assert isinstance(
block.model, str
), f"The model should be a string: {block.model}"
assert isinstance(
block.backend, (dict, str)
), f"The backend should be a string or a dictionnary: {block.backend}"
match block.backend:
backend, (dict, str)
), f"The backend should be a string or a dictionnary: {backend}"
match backend:
case {"transformers": device}:
assert isinstance(block.backend, dict)
assert isinstance(backend, dict)
from granite_io import make_backend

backend = make_backend(
"transformers",
{
"model_name": block.model,
"model_name": model,
"device": device,
},
)
Expand All @@ -42,14 +43,15 @@ def processor_of_block(block: GraniteioModelBlock):
backend = make_backend(
backend_name,
{
"model_name": block.model,
"model_name": model,
},
)
case _:
assert False, f"Unexpected backend: {block.backend}"
processor_name = block.processor
if processor_name is None:
processor_name = block.model
assert False, f"Unexpected backend: {backend}"
if block.processor is None:
processor_name = model
else:
processor_name = value_of_expr(block.processor)
assert isinstance(
processor_name, str
), f"The processor should be a string: {processor_name}"
Expand All @@ -73,10 +75,14 @@ async def async_generate_text(
block: GraniteioModelBlock,
messages: ModelInput,
) -> tuple[dict[str, Any], Any]:
if block.parameters is None:
parameters = None
else:
parameters = value_of_expr(block.parameters)
try:
assert block.parameters is None or isinstance(block.parameters, dict)
assert parameters is None or isinstance(parameters, dict)
io_processor = GraniteioModel.processor_of_block(block)
inputs = GraniteioModel.build_message(messages, block.parameters)
inputs = GraniteioModel.build_message(messages, parameters)
result = io_processor.create_chat_completion(inputs) # pyright: ignore
try: # TODO: update when new version of granite-io is released
message = result.next_message.model_dump()
Expand All @@ -88,7 +94,9 @@ async def async_generate_text(
raw_result,
)
except Exception as exc:
message = f"Error during '{block.model}' model call: {repr(exc)}"
message = (
f"Error during '{value_of_expr(block.model)}' model call: {repr(exc)}"
)
loc = block.pdl__location
raise PDLRuntimeError(
message,
Expand Down
Loading