diff --git a/pdl-live-react/src/pdl_ast.d.ts b/pdl-live-react/src/pdl_ast.d.ts
index d118f4b38..d7272dfe6 100644
--- a/pdl-live-react/src/pdl_ast.d.ts
+++ b/pdl-live-react/src/pdl_ast.d.ts
@@ -4055,6 +4055,7 @@ export interface ContributeValue {
 }
 export interface LocalizedExpression {
   expr: Expr
+  pdl__result?: unknown
   pdl__location?: PdlLocationType | null
 }
 export interface Expr {
diff --git a/src/pdl/pdl-schema.json b/src/pdl/pdl-schema.json
index db412f138..ec33ce6d9 100644
--- a/src/pdl/pdl-schema.json
+++ b/src/pdl/pdl-schema.json
@@ -7294,6 +7294,16 @@
       "expr": {
         "title": "Expr"
       },
+      "pdl__result": {
+        "anyOf": [
+          {},
+          {
+            "type": "null"
+          }
+        ],
+        "default": null,
+        "title": "Pdl Result"
+      },
       "pdl__location": {
         "anyOf": [
           {
diff --git a/src/pdl/pdl_ast.py b/src/pdl/pdl_ast.py
index a01b2ea0f..538e16526 100644
--- a/src/pdl/pdl_ast.py
+++ b/src/pdl/pdl_ast.py
@@ -84,7 +84,8 @@ class LocalizedExpression(BaseModel, Generic[LocalizedExpressionT]):
         arbitrary_types_allowed=True,
         model_title_generator=(lambda _: "LocalizedExpression"),
     )
-    expr: LocalizedExpressionT
+    expr: Any
+    pdl__result: Optional[LocalizedExpressionT] = None
     pdl__location: Optional[PdlLocationType] = None
diff --git a/src/pdl/pdl_dumper.py b/src/pdl/pdl_dumper.py
index 1777ba72b..8918204e0 100644
--- a/src/pdl/pdl_dumper.py
+++ b/src/pdl/pdl_dumper.py
@@ -18,6 +18,7 @@
     DataBlock,
     EmptyBlock,
     ErrorBlock,
+    ExpressionType,
     FunctionBlock,
     GetBlock,
     GraniteioModelBlock,
@@ -29,6 +30,7 @@
     LastOfBlock,
     LitellmModelBlock,
     LitellmParameters,
+    LocalizedExpression,
     MatchBlock,
     MessageBlock,
     ObjectBlock,
@@ -112,29 +114,28 @@ def block_to_dict(  # noqa: C901
     match block:
         case LitellmModelBlock():
             d["platform"] = str(block.platform)
-            d["model"] = block.model
-            if block.input is not None:
-                d["input"] = block_to_dict(block.input, json_compatible)
+            d["model"] = expr_to_dict(block.model, json_compatible)
+            d["input"] = block_to_dict(block.input, json_compatible)
             if block.parameters is not None:
                 if isinstance(block.parameters, LitellmParameters):
                     d["parameters"] = block.parameters.model_dump(
                         exclude_unset=True, exclude_defaults=True
                     )
                 else:
-                    d["parameters"] = block.parameters
+                    d["parameters"] = expr_to_dict(block.parameters, json_compatible)
             if block.modelResponse is not None:
                 d["modelResponse"] = block.modelResponse
             if block.pdl__usage is not None:
                 d["pdl__usage"] = usage_to_dict(block.pdl__usage)
         case GraniteioModelBlock():
-            d["model"] = block.model
+            d["model"] = expr_to_dict(block.model, json_compatible)
             d["platform"] = str(block.platform)
-            d["backend"] = block.backend
-            d["processor"] = block.processor
-            if block.input is not None:
-                d["input"] = block_to_dict(block.input, json_compatible)
+            d["backend"] = expr_to_dict(block.backend, json_compatible)
+            if block.processor is not None:
+                d["processor"] = expr_to_dict(block.processor, json_compatible)
+            d["input"] = block_to_dict(block.input, json_compatible)
             if block.parameters is not None:
-                d["parameters"] = block.parameters
+                d["parameters"] = expr_to_dict(block.parameters, json_compatible)
             if block.modelResponse is not None:
                 d["modelResponse"] = block.modelResponse
             if block.pdl__usage is not None:
@@ -147,7 +148,7 @@
         case GetBlock():
             d["get"] = block.get
         case DataBlock():
-            d["data"] = data_to_dict(block.data, json_compatible)
+            d["data"] = expr_to_dict(block.data, json_compatible)
             if block.raw:
                 d["raw"] = block.raw
         case TextBlock():
@@ -171,7 +172,7 @@
         case MessageBlock():
             d["content"] = block_to_dict(block.content, json_compatible)
         case ReadBlock():
-            d["read"] = block.read
+            d["read"] = expr_to_dict(block.read, json_compatible)
             d["message"] = block.message
             d["multiline"] = block.multiline
         case IncludeBlock():
@@ -183,18 +184,18 @@
             if block.pdl__trace:
                 d["pdl__trace"] = block_to_dict(block.pdl__trace, json_compatible)
         case IfBlock():
-            d["if"] = block.condition
+            d["if"] = expr_to_dict(block.condition, json_compatible)
             d["then"] = block_to_dict(block.then, json_compatible)
             if block.else_ is not None:
                 d["else"] = block_to_dict(block.else_, json_compatible)
             if block.if_result is not None:
                 d["if_result"] = block.if_result
         case MatchBlock():
-            d["match"] = block.match_
+            d["match"] = expr_to_dict(block.match_, json_compatible)
             d["with"] = [
                 {
                     "case": pattern_to_dict(match_case.case),
-                    "if": match_case.if_,
+                    "if": expr_to_dict(match_case.if_, json_compatible),
                     "then": block_to_dict(match_case.then, json_compatible),
                     "pdl__case_result": match_case.pdl__case_result,
                     "pdl__if_result": match_case.pdl__if_result,
@@ -203,11 +204,17 @@
                 for match_case in block.with_
             ]
         case RepeatBlock():
-            d["for"] = block.for_
-            d["while"] = block.while_
+            if block.for_ is not None:
+                d["for"] = expr_to_dict(block.for_, json_compatible)
+            if block.while_ is not None:
+                d["while"] = expr_to_dict(block.while_, json_compatible)
             d["repeat"] = block_to_dict(block.repeat, json_compatible)
-            d["until"] = block.until
-            d["max_iterations"] = block.max_iterations
+            if block.until is not None:
+                d["until"] = expr_to_dict(block.until, json_compatible)
+            if block.max_iterations is not None:
+                d["max_iterations"] = expr_to_dict(
+                    block.max_iterations, json_compatible
+                )
             d["join"] = join_to_dict(block.join)
             if block.pdl__trace is not None:
                 d["pdl__trace"] = [
@@ -219,8 +226,8 @@
             # if block.scope is not None:
             #     d["scope"] = scope_to_dict(block.scope, json_compatible)
         case CallBlock():
-            d["call"] = block.call
-            d["args"] = data_to_dict(block.args, json_compatible)
+            d["call"] = expr_to_dict(block.call, json_compatible)
+            d["args"] = expr_to_dict(block.args, json_compatible)
             if block.pdl__trace is not None:
                 d["pdl__trace"] = block_to_dict(
                     block.pdl__trace, json_compatible
@@ -257,7 +264,7 @@
     return d


-def data_to_dict(data: Any, json_compatible):
+def data_to_dict(data: Any, json_compatible: bool):
     if json_compatible:
         d = as_json(data)
     else:
@@ -265,6 +272,16 @@
     return d


+def expr_to_dict(expr: ExpressionType, json_compatible: bool):
+    if isinstance(expr, LocalizedExpression):
+        d = {"expr": data_to_dict(expr.expr, json_compatible)}
+        if expr.pdl__result is not None:
+            d["pdl__result"] = data_to_dict(expr.pdl__result, json_compatible)
+    else:
+        d = data_to_dict(expr, json_compatible)
+    return d
+
+
 def timing_to_dict(timing: PdlTiming) -> dict:
     d: dict = {}
     if timing.start_nanos != 0:
diff --git a/src/pdl/pdl_granite_io.py b/src/pdl/pdl_granite_io.py
index aa96336d7..6dd739c2d 100644
--- a/src/pdl/pdl_granite_io.py
+++ b/src/pdl/pdl_granite_io.py
@@ -13,26 +13,27 @@
 )
 from .pdl_lazy import PdlConst, PdlLazy, lazy_apply
 from .pdl_llms import _LOOP
+from .pdl_utils import value_of_expr


 class GraniteioModel:
     @staticmethod
     def processor_of_block(block: GraniteioModelBlock):
+        model = value_of_expr(block.model)
+        backend = value_of_expr(block.backend)
+        assert isinstance(model, str), f"The model should be a string: {model}"
         assert isinstance(
-            block.model, str
-        ), f"The model should be a string: {block.model}"
-        assert isinstance(
-            block.backend, (dict, str)
-        ), f"The backend should be a string or a dictionnary: {block.backend}"
-        match block.backend:
+            backend, (dict, str)
+        ), f"The backend should be a string or a dictionnary: {backend}"
+        match backend:
             case {"transformers": device}:
-                assert isinstance(block.backend, dict)
+                assert isinstance(backend, dict)
                 from granite_io import make_backend

                 backend = make_backend(
                     "transformers",
                     {
-                        "model_name": block.model,
+                        "model_name": model,
                         "device": device,
                     },
                 )
@@ -42,14 +43,15 @@ def processor_of_block(block: GraniteioModelBlock):
                 backend = make_backend(
                     backend_name,
                     {
-                        "model_name": block.model,
+                        "model_name": model,
                     },
                 )
             case _:
-                assert False, f"Unexpected backend: {block.backend}"
-        processor_name = block.processor
-        if processor_name is None:
-            processor_name = block.model
+                assert False, f"Unexpected backend: {backend}"
+        if block.processor is None:
+            processor_name = model
+        else:
+            processor_name = value_of_expr(block.processor)
         assert isinstance(
             processor_name, str
         ), f"The processor should be a string: {processor_name}"
@@ -73,10 +75,14 @@ async def async_generate_text(
         block: GraniteioModelBlock,
         messages: ModelInput,
     ) -> tuple[dict[str, Any], Any]:
+        if block.parameters is None:
+            parameters = None
+        else:
+            parameters = value_of_expr(block.parameters)
         try:
-            assert block.parameters is None or isinstance(block.parameters, dict)
+            assert parameters is None or isinstance(parameters, dict)
             io_processor = GraniteioModel.processor_of_block(block)
-            inputs = GraniteioModel.build_message(messages, block.parameters)
+            inputs = GraniteioModel.build_message(messages, parameters)
             result = io_processor.create_chat_completion(inputs)  # pyright: ignore
             try:  # TODO: update when new version of granite-io is released
                 message = result.next_message.model_dump()
@@ -88,7 +94,9 @@ async def async_generate_text(
                 raw_result,
             )
         except Exception as exc:
-            message = f"Error during '{block.model}' model call: {repr(exc)}"
+            message = (
+                f"Error during '{value_of_expr(block.model)}' model call: {repr(exc)}"
+            )
             loc = block.pdl__location
             raise PDLRuntimeError(
                 message,
diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py
index 2f4df64ba..6f8db5648 100644
--- a/src/pdl/pdl_interpreter.py
+++ b/src/pdl/pdl_interpreter.py
@@ -104,6 +104,7 @@
     lazy_messages_concat,
     replace_contribute_value,
     stringify,
+    value_of_expr,
 )

 empty_scope: ScopeType = PdlDict({"pdl_context": PdlList([])})
@@ -237,13 +238,14 @@ def process_block(
     if not isinstance(block, Block):
         start = time.time_ns()
         try:
-            result = PdlConst(process_expr(scope, block, loc))
+            v, expr = process_expr(scope, block, loc)
         except PDLRuntimeExpressionError as exc:
             raise PDLRuntimeError(
                 exc.message,
                 loc=exc.loc or loc,
                 trace=ErrorBlock(msg=exc.message, pdl__location=loc, program=block),
             ) from exc
+        result = PdlConst(v)
         stringified_result = lazy_apply(stringify, result)
         background = PdlList(
             [
@@ -259,7 +261,7 @@ def process_block(
             ]
         )
         trace = DataBlock(
-            data=block,
+            data=expr,
             pdl__result=stringified_result,
             pdl__timing=PdlTiming(start_nanos=start, end_nanos=time.time_ns()),
             pdl__id=".".join(state.id_stack),
@@ -573,7 +575,7 @@ def process_block_body(
                 {"role": state.role, "content": content, "defsite": block.pdl__id}
             )
         case IfBlock():
-            b = process_condition_of(block, "condition", scope, loc, "if")
+            b, if_trace = process_condition_of(block, "condition", scope, loc, "if")
            if b:
                 state = state.with_iter(0)
                 result, background, scope, trace = process_block_of(
@@ -592,6 +594,7 @@
                 trace = block
             trace = trace.model_copy(
                 update={
+                    "condition": if_trace,
                     "if_result": b,
                 }
             )
@@ -623,7 +626,8 @@
             if "if_" in match_case.model_fields_set and match_case.if_ is not None:
                 loc_if = append(loc_i, "if")
                 try:
-                    b = process_expr(new_scope, match_case.if_, loc_if)
+                    b, if_trace = process_expr(new_scope, match_case.if_, loc_if)
+                    match_case = match_case.model_copy(update={"if_": if_trace})
                 except PDLRuntimeExpressionError as exc:
                     cases.append(match_case)
                     block.with_ = cases
@@ -718,7 +722,7 @@
                     break
                 if lengths is not None and iidx >= lengths[0]:
                     break
-                stay = process_condition_of(block, "while_", scope, loc, "while")
+                stay, _ = process_condition_of(block, "while_", scope, loc, "while")
                 if not stay:
                     break
                 iteration_state = iteration_state.with_iter(iidx)
@@ -763,7 +767,7 @@
                     iter_trace.append(body_trace)
                     iteration_state = iteration_state.with_pop()
                     iidx = iidx + 1
-                    stop = process_condition_of(block, "until", scope, loc)
+                    stop, _ = process_condition_of(block, "until", scope, loc)
                     if stop:
                         break
         except PDLRuntimeError as exc:
@@ -1043,17 +1047,25 @@ def combine_results(iteration_type: IterationType, results: list[PdlLazy[Any]]):
 def process_contribute(
     block: BlockTypeTVarProcessContribute, scope: ScopeType, loc: PdlLocationType
 ) -> tuple[Any, BlockTypeTVarProcessContribute]:
+    result: list[ContributeTarget | dict[str, ContributeValue]]
+    value_trace: LocalizedExpression[
+        list[ContributeTarget | dict[str, ContributeValue]]
+    ]
     value = get_contribute_value(block.contribute)
+    if value is None:
+        return None, block
     loc = append(loc, "contribute")
     try:
-        result: ContributeValue = process_expr(scope, value, loc)  # pyright: ignore
+        result, value_trace = process_expr(scope, value, loc)
     except PDLRuntimeExpressionError as exc:
         raise PDLRuntimeError(
             exc.message,
             loc=exc.loc or loc,
             trace=ErrorBlock(msg=exc.message, pdl__location=loc, program=block),
         ) from exc
-    replace = replace_contribute_value(block.contribute, result)
+    replace = replace_contribute_value(
+        block.contribute, ContributeValue(value=value_trace)
+    )
     trace = block.model_copy(update={"contribute": replace})
     return result, trace
@@ -1070,17 +1082,19 @@ def process_expr_of(
     loc: PdlLocationType,
     field_alias: Optional[str] = None,
 ) -> tuple[Any, BlockTypeTVarProcessExprOf]:
+    result: Any
+    expr_trace: LocalizedExpression[Any]
     expr = getattr(block, field)
     loc = append(loc, field_alias or field)
     try:
-        result: Any = process_expr(scope, expr, loc)
+        result, expr_trace = process_expr(scope, expr, loc)
     except PDLRuntimeExpressionError as exc:
         raise PDLRuntimeError(
             exc.message,
             loc=exc.loc or loc,
             trace=ErrorBlock(msg=exc.message, pdl__location=loc, program=block),
         ) from exc
-    trace = block.model_copy(update={field: result})
+    trace = block.model_copy(update={field: expr_trace})
     return result, trace
@@ -1090,18 +1104,20 @@ def process_condition_of(
     scope: ScopeType,
     loc: PdlLocationType,
     field_alias: Optional[str] = None,
-) -> bool:
+) -> tuple[bool, LocalizedExpression[bool]]:
+    result: bool
+    expr_trace: LocalizedExpression[bool]
     expr = getattr(block, field)
     loc = append(loc, field_alias or field)
     try:
-        result: bool = process_expr(scope, expr, loc)
+        result, expr_trace = process_expr(scope, expr, loc)
     except PDLRuntimeExpressionError as exc:
         raise PDLRuntimeError(
             exc.message,
             loc=exc.loc or loc,
             trace=ErrorBlock(msg=exc.message, pdl__location=loc, program=block),
         ) from exc
-    return result
+    return result, expr_trace


 EXPR_START_STRING = "${"
@@ -1112,10 +1128,26 @@ def process_condition_of(
 def process_expr(  # pylint: disable=too-many-return-statements
     scope: ScopeType, expr: ExpressionType[ProcessExprT], loc: PdlLocationType
-) -> ProcessExprT:
+) -> tuple[ProcessExprT, LocalizedExpression[ProcessExprT]]:
     result: ProcessExprT
     if isinstance(expr, LocalizedExpression):
-        return process_expr(scope, expr.expr, loc)
+        result = _process_expr(scope, expr.expr, loc)
+        trace = expr.model_copy(update={"pdl__result": result})
+    else:
+        result = _process_expr(scope, expr, loc)
+        trace = LocalizedExpression(expr=expr, pdl__result=result, pdl__location=loc)
+    return (result, trace)
+
+
+_ProcessExprT = TypeVar("_ProcessExprT")
+
+
+def _process_expr(  # pylint: disable=too-many-return-statements
+    scope: ScopeType, expr: ExpressionType[_ProcessExprT], loc: PdlLocationType
+) -> _ProcessExprT:
+    result: _ProcessExprT
+    if isinstance(expr, LocalizedExpression):
+        return _process_expr(scope, expr.expr, loc)
     if isinstance(expr, str):
         try:
             env = Environment(  # nosec B701
@@ -1180,15 +1212,15 @@
     if isinstance(expr, list):
         result_list: list[Any] = []
         for index, x in enumerate(expr):
-            res: Any = process_expr(scope, x, append(loc, "[" + str(index) + "]"))
+            res: Any = _process_expr(scope, x, append(loc, "[" + str(index) + "]"))
             result_list.append(res)
         return result_list  # type: ignore
     if isinstance(expr, dict):
         result_dict: dict[str, Any] = {}
         for k, v in expr.items():
             k_loc = append(loc, k)
-            k_res: str = process_expr(scope, k, k_loc)
-            v_res: Any = process_expr(scope, v, k_loc)
+            k_res: str = _process_expr(scope, k, k_loc)
+            v_res: Any = _process_expr(scope, v, k_loc)
             result_dict[k_res] = v_res
         return result_dict  # type: ignore
     return expr
@@ -1211,7 +1243,9 @@ def process_call_model(
     BlockTypeTVarProcessCallModel,
 ]:
     # evaluate model name
-    _, concrete_block = process_expr_of(block, "model", scope, loc)  # pyright: ignore
+    model_id, concrete_block = process_expr_of(
+        block, "model", scope, loc  # pyright: ignore
+    )  # pyright: ignore
     # evaluate model params
     match concrete_block:
         case LitellmModelBlock():
@@ -1229,7 +1263,7 @@
                 concrete_block.parameters, dict
             ):
                 concrete_block.parameters = apply_defaults(
-                    str(concrete_block.model),
+                    str(model_id),
                     concrete_block.parameters or {},
                     scope.get("pdl_model_default_parameters", []),
                 )
@@ -1299,14 +1333,14 @@ def get_transformed_inputs(kwargs):
             scope = scope | {block.modelResponse: raw_result}
         return result, background, scope, trace
     except httpx.RequestError as exc:
-        message = f"model '{block.model}' encountered {repr(exc)} trying to {exc.request.method} against {exc.request.url}"
+        message = f"model '{model_id}' encountered {repr(exc)} trying to {exc.request.method} against {exc.request.url}"
         raise PDLRuntimeError(
             message,
             loc=loc,
             trace=ErrorBlock(msg=message, pdl__location=loc, program=concrete_block),
         ) from exc
     except Exception as exc:
-        message = f"Error during '{block.model}' model call: {repr(exc)}"
+        message = f"Error during '{model_id}' model call: {repr(exc)}"
         raise PDLRuntimeError(
             message,
             loc=loc,
@@ -1341,15 +1375,18 @@ def generate_client_response_streaming(
     msg_stream: Generator[dict[str, Any], Any, Any]
     match block:
         case LitellmModelBlock():
-            assert isinstance(block.model, str)  # block is a "concrete block"
-            assert block.parameters is None or isinstance(
-                block.parameters, dict
+            if block.parameters is None:
+                parameters = None
+            else:
+                parameters = value_of_expr(block.parameters)  # pyright: ignore
+            assert parameters is None or isinstance(
+                parameters, dict
             )  # block is a "concrete block"
             msg_stream = LitellmModel.generate_text_stream(
-                model_id=block.model,
+                model_id=value_of_expr(block.model),
                 messages=model_input,
                 spec=block.spec,
-                parameters=litellm_parameters_to_dict(block.parameters),
+                parameters=litellm_parameters_to_dict(parameters),
             )
         case GraniteioModelBlock():
             # TODO: curently fallback to the non-streaming interface
@@ -1413,18 +1450,21 @@ def generate_client_response_single(
     block: LitellmModelBlock | GraniteioModelBlock,
     model_input: ModelInput,
 ) -> tuple[LazyMessage, PdlLazy[Any]]:
-    assert block.parameters is None or isinstance(
-        block.parameters, dict
+    if block.parameters is None:
+        parameters = None
+    else:
+        parameters = value_of_expr(block.parameters)  # pyright:ignore
+    assert parameters is None or isinstance(
+        parameters, dict
     )  # block is a "concrete block"
     block.pdl__usage = PdlUsage()
     match block:
         case LitellmModelBlock():
-            assert isinstance(block.model, str)  # block is a "concrete block"
-
             message, response = LitellmModel.generate_text(
                 block=block,
+                model_id=value_of_expr(block.model),
                 messages=model_input,
-                parameters=litellm_parameters_to_dict(block.parameters),
+                parameters=litellm_parameters_to_dict(parameters),
             )
         case GraniteioModelBlock():
             from .pdl_granite_io import GraniteioModel
@@ -1455,7 +1495,15 @@ def process_call_code(
     code_s = ""
     match block:
         case ArgsBlock():
-            code_a = [process_expr(scope, arg_i, loc) for arg_i in block.args]
+            code_a = []
+            args_trace: list[LocalizedExpression[str]] = []
+            for expr_i in block.args:
+                arg_i: str
+                trace_i: LocalizedExpression[str]
+                arg_i, trace_i = process_expr(scope, expr_i, loc)
+                code_a.append(arg_i)
+                args_trace.append(trace_i)
+            block = block.model_copy(update={"args": args_trace})
         case CodeBlock():
             code_, _, _, block = process_block_of(
                 block,
@@ -1846,4 +1894,5 @@ def parse_result(parser: ParserType, text: str) -> JSONReturnType:


 def get_var(var: str, scope: ScopeType, loc: PdlLocationType) -> Any:
-    return process_expr(scope, f"{EXPR_START_STRING} {var} {EXPR_END_STRING}", loc)
+    v, _ = process_expr(scope, f"{EXPR_START_STRING} {var} {EXPR_END_STRING}", loc)
+    return v
diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py
index 3e2ce04d1..bc9302e9d 100644
--- a/src/pdl/pdl_llms.py
+++ b/src/pdl/pdl_llms.py
@@ -40,12 +40,11 @@ class LitellmModel:
     @staticmethod
     async def async_generate_text(
         block: LitellmModelBlock,
+        model_id: str,
         messages: ModelInput,
         parameters: dict[str, Any],
     ) -> tuple[dict[str, Any], Any]:
         try:
-            assert isinstance(block.model, str)
-            model_id = block.model
             spec = block.spec
             parameters = set_structured_decoding_parameters(spec, parameters)
             if parameters.get("mock_response") is not None:
@@ -84,6 +83,7 @@ async def async_generate_text(
     @staticmethod
     def generate_text(
         block: LitellmModelBlock,
+        model_id: str,
         messages: ModelInput,
         parameters: dict[str, Any],
     ) -> tuple[LazyMessage, PdlLazy[Any]]:
@@ -92,6 +92,7 @@ def generate_text(
         future = asyncio.run_coroutine_threadsafe(
             LitellmModel.async_generate_text(
                 block,
+                model_id,
                 messages,
                 parameters,
             ),
diff --git a/src/pdl/pdl_utils.py b/src/pdl/pdl_utils.py
index ac93c620f..52c636b52 100644
--- a/src/pdl/pdl_utils.py
+++ b/src/pdl/pdl_utils.py
@@ -5,8 +5,10 @@
 from .pdl_ast import (
     ContributeTarget,
     ContributeValue,
+    ExpressionType,
     FunctionBlock,
     LazyMessages,
+    LocalizedExpression,
     get_sampling_defaults,
 )
 from .pdl_lazy import lazy_apply2
@@ -63,6 +65,20 @@ def stringify(result):
     return s


+ValueOfExprT = TypeVar("ValueOfExprT")
+
+
+def value_of_expr(expr: ExpressionType[ValueOfExprT]) -> ValueOfExprT:
+    if isinstance(expr, LocalizedExpression):
+        if "pdl__result" in expr.model_fields_set:
+            v = expr.pdl__result
+        else:
+            v = expr.expr
+    else:
+        v = expr
+    return v  # type: ignore
+
+
 def replace_contribute_value(
     contribute: Sequence[ContributeTarget | dict[str, ContributeValue]],
     value: ContributeValue,
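
Not part of the diff above: a minimal, self-contained sketch of the contract the new pdl__result field and value_of_expr helper appear to implement, i.e. once process_expr records the evaluated value on the LocalizedExpression trace, downstream code reads it back instead of re-evaluating. The class below is a simplified, non-generic stand-in for LocalizedExpression, pydantic v2 is assumed, and the expression and model-name strings are made up for illustration.

# Illustrative only -- mirrors the helper added to pdl_utils.py in this diff.
from typing import Any, Optional

from pydantic import BaseModel


class LocalizedExpression(BaseModel):
    """Simplified, non-generic stand-in for pdl_ast.LocalizedExpression."""

    expr: Any
    pdl__result: Optional[Any] = None


def value_of_expr(expr: Any) -> Any:
    # Same logic as the new helper: prefer the recorded result when present.
    if isinstance(expr, LocalizedExpression):
        if "pdl__result" in expr.model_fields_set:
            return expr.pdl__result
        return expr.expr
    return expr


# Before evaluation, only the source expression is available.
pending = LocalizedExpression(expr="${ model_name }")
assert value_of_expr(pending) == "${ model_name }"

# After evaluation, the trace carries the computed value as well.
evaluated = pending.model_copy(update={"pdl__result": "example-model"})
assert value_of_expr(evaluated) == "example-model"

# Plain (non-localized) values pass through unchanged.
assert value_of_expr(42) == 42

Checking model_fields_set rather than comparing against None lets the helper distinguish an expression that has not been evaluated yet from one whose evaluated result is legitimately None.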