diff --git a/pdl-live-react/src/pdl_ast.d.ts b/pdl-live-react/src/pdl_ast.d.ts index 1e3bec9c6..12f8123d3 100644 --- a/pdl-live-react/src/pdl_ast.d.ts +++ b/pdl-live-react/src/pdl_ast.d.ts @@ -2354,6 +2354,11 @@ export type PdlContext17 = export type PdlId17 = string | null export type PdlIsLeaf17 = true export type Kind17 = "model" +/** + * Model name used by the backend. + * + */ +export type Model = LocalizedExpression | string /** * Messages to send to the model. * @@ -2411,11 +2416,6 @@ export type Backend = | { [k: string]: unknown } -/** - * IO Processor name. - * - */ -export type Processor = LocalizedExpression | string | null /** * Parameters sent to the model. * @@ -3210,7 +3210,7 @@ export interface GraniteioModelBlock { pdl__timing?: PdlTiming | null pdl__is_leaf?: PdlIsLeaf17 kind?: Kind17 - model: unknown + model: Model input?: Input modelResponse?: Modelresponse /** @@ -3221,7 +3221,7 @@ export interface GraniteioModelBlock { pdl__model_input?: PdlModelInput platform?: Platform backend: Backend - processor?: Processor + processor?: unknown parameters?: Parameters } /** diff --git a/pdl-live-react/src/pdl_ast_utils.ts b/pdl-live-react/src/pdl_ast_utils.ts index 8326e798c..541c3c048 100644 --- a/pdl-live-react/src/pdl_ast_utils.ts +++ b/pdl-live-react/src/pdl_ast_utils.ts @@ -1,6 +1,6 @@ import { match, P } from "ts-pattern" -import { Backend, PdlBlock, Processor } from "./pdl_ast" +import { Backend, PdlBlock } from "./pdl_ast" import { ExpressionT, isArgs } from "./helpers" export function map_block_children( @@ -74,10 +74,7 @@ export function map_block_children( : undefined // @ts-expect-error: f_expr does not preserve the type of the expression const backend: Backend = f_expr(block.backend) - // @ts-expect-error: f_expr does not preserve the type of the expression - const processor: Processor = block.processor - ? f_expr(block.processor) - : undefined + const processor = block.processor ? f_expr(block.processor) : undefined return { ...block, model, diff --git a/src/pdl/pdl-schema.json b/src/pdl/pdl-schema.json index 74b2e5775..bd4e9b2c2 100644 --- a/src/pdl/pdl-schema.json +++ b/src/pdl/pdl-schema.json @@ -4461,7 +4461,6 @@ { "$ref": "#/$defs/LocalizedExpression_TypeVar_" }, - {}, { "type": "string" } @@ -4626,6 +4625,7 @@ { "type": "string" }, + {}, { "type": "null" } diff --git a/src/pdl/pdl_ast.py b/src/pdl/pdl_ast.py index 577518d27..b5294ab04 100644 --- a/src/pdl/pdl_ast.py +++ b/src/pdl/pdl_ast.py @@ -548,13 +548,13 @@ class GraniteioModelBlock(ModelBlock): platform: Literal[ModelPlatform.GRANITEIO] = ModelPlatform.GRANITEIO """Optional field to ensure that the block is using granite-io. """ - model: ExpressionType[object] + model: ExpressionType[str] """Model name used by the backend. """ - backend: ExpressionType[str | dict[str, Any]] + backend: ExpressionType[str | dict[str, Any | object]] """Backend name and configuration. """ - processor: Optional[ExpressionType[str]] = None + processor: Optional[ExpressionType[str | object]] = None """IO Processor name. """ parameters: Optional[ExpressionType[dict[str, Any]]] = None diff --git a/src/pdl/pdl_granite_io.py b/src/pdl/pdl_granite_io.py index adce5d309..4690e11ee 100644 --- a/src/pdl/pdl_granite_io.py +++ b/src/pdl/pdl_granite_io.py @@ -19,17 +19,18 @@ class GraniteioModel: @staticmethod def processor_of_block(block: GraniteioModelBlock): - model = value_of_expr(block.model) + from granite_io import make_backend, make_io_processor + from granite_io.backend.base import Backend + from granite_io.io import InputOutputProcessor + + processor = value_of_expr(block.processor) + if isinstance(processor, InputOutputProcessor): + return processor + model: str = value_of_expr(block.model) backend = value_of_expr(block.backend) assert isinstance(model, str), f"The model should be a string: {model}" - assert isinstance( - backend, (dict, str) - ), f"The backend should be a string or a dictionary: {backend}" match backend: case {"transformers": device}: - assert isinstance(backend, dict) - from granite_io import make_backend - backend = make_backend( "transformers", { @@ -37,26 +38,27 @@ def processor_of_block(block: GraniteioModelBlock): "device": device, }, ) - case backend_name if isinstance(backend_name, str): - from granite_io import make_backend - + case str(): backend = make_backend( - backend_name, + backend, { "model_name": model, }, ) + case Backend(): + pass case _: assert False, f"Unexpected backend: {backend}" - if block.processor is None: + if processor is None: processor_name = model else: - processor_name = value_of_expr(block.processor) + assert isinstance( + processor, str + ), f"The processor should be a string: {processor}" + processor_name = value_of_expr(processor) assert isinstance( processor_name, str ), f"The processor should be a string: {processor_name}" - from granite_io import make_io_processor - io_processor = make_io_processor(processor_name, backend=backend) return io_processor @@ -87,7 +89,7 @@ async def async_generate_text( inputs ) try: # TODO: update when new version of granite-io is released - message = result.next_message.model_dump() + message = result.next_message.model_dump() # pyright: ignore except AttributeError: message = result.results[0].next_message.model_dump() raw_result = result.model_dump()