Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pdl-live-react/src/pdl_ast.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2354,6 +2354,11 @@ export type PdlContext17 =
export type PdlId17 = string | null
export type PdlIsLeaf17 = true
export type Kind17 = "model"
/**
* Model name used by the backend.
*
*/
export type Model = LocalizedExpression | string
/**
* Messages to send to the model.
*
Expand Down Expand Up @@ -2411,11 +2416,6 @@ export type Backend =
| {
[k: string]: unknown
}
/**
* IO Processor name.
*
*/
export type Processor = LocalizedExpression | string | null
/**
* Parameters sent to the model.
*
Expand Down Expand Up @@ -3210,7 +3210,7 @@ export interface GraniteioModelBlock {
pdl__timing?: PdlTiming | null
pdl__is_leaf?: PdlIsLeaf17
kind?: Kind17
model: unknown
model: Model
input?: Input
modelResponse?: Modelresponse
/**
Expand All @@ -3221,7 +3221,7 @@ export interface GraniteioModelBlock {
pdl__model_input?: PdlModelInput
platform?: Platform
backend: Backend
processor?: Processor
processor?: unknown
parameters?: Parameters
}
/**
Expand Down
7 changes: 2 additions & 5 deletions pdl-live-react/src/pdl_ast_utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { match, P } from "ts-pattern"

import { Backend, PdlBlock, Processor } from "./pdl_ast"
import { Backend, PdlBlock } from "./pdl_ast"
import { ExpressionT, isArgs } from "./helpers"

export function map_block_children(
Expand Down Expand Up @@ -74,10 +74,7 @@ export function map_block_children(
: undefined
// @ts-expect-error: f_expr does not preserve the type of the expression
const backend: Backend = f_expr(block.backend)
// @ts-expect-error: f_expr does not preserve the type of the expression
const processor: Processor = block.processor
? f_expr(block.processor)
: undefined
const processor = block.processor ? f_expr(block.processor) : undefined
return {
...block,
model,
Expand Down
2 changes: 1 addition & 1 deletion src/pdl/pdl-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -4461,7 +4461,6 @@
{
"$ref": "#/$defs/LocalizedExpression_TypeVar_"
},
{},
{
"type": "string"
}
Expand Down Expand Up @@ -4626,6 +4625,7 @@
{
"type": "string"
},
{},
{
"type": "null"
}
Expand Down
6 changes: 3 additions & 3 deletions src/pdl/pdl_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,13 @@ class GraniteioModelBlock(ModelBlock):
platform: Literal[ModelPlatform.GRANITEIO] = ModelPlatform.GRANITEIO
"""Optional field to ensure that the block is using granite-io.
"""
model: ExpressionType[object]
model: ExpressionType[str]
"""Model name used by the backend.
"""
backend: ExpressionType[str | dict[str, Any]]
backend: ExpressionType[str | dict[str, Any | object]]
"""Backend name and configuration.
"""
processor: Optional[ExpressionType[str]] = None
processor: Optional[ExpressionType[str | object]] = None
"""IO Processor name.
"""
parameters: Optional[ExpressionType[dict[str, Any]]] = None
Expand Down
34 changes: 18 additions & 16 deletions src/pdl/pdl_granite_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,46 @@
class GraniteioModel:
@staticmethod
def processor_of_block(block: GraniteioModelBlock):
model = value_of_expr(block.model)
from granite_io import make_backend, make_io_processor
from granite_io.backend.base import Backend
from granite_io.io import InputOutputProcessor

processor = value_of_expr(block.processor)
if isinstance(processor, InputOutputProcessor):
return processor
model: str = value_of_expr(block.model)
backend = value_of_expr(block.backend)
assert isinstance(model, str), f"The model should be a string: {model}"
assert isinstance(
backend, (dict, str)
), f"The backend should be a string or a dictionary: {backend}"
match backend:
case {"transformers": device}:
assert isinstance(backend, dict)
from granite_io import make_backend

backend = make_backend(
"transformers",
{
"model_name": model,
"device": device,
},
)
case backend_name if isinstance(backend_name, str):
from granite_io import make_backend

case str():
backend = make_backend(
backend_name,
backend,
{
"model_name": model,
},
)
case Backend():
pass
case _:
assert False, f"Unexpected backend: {backend}"
if block.processor is None:
if processor is None:
processor_name = model
else:
processor_name = value_of_expr(block.processor)
assert isinstance(
processor, str
), f"The processor should be a string: {processor}"
processor_name = value_of_expr(processor)
assert isinstance(
processor_name, str
), f"The processor should be a string: {processor_name}"
from granite_io import make_io_processor

io_processor = make_io_processor(processor_name, backend=backend)
return io_processor

Expand Down Expand Up @@ -87,7 +89,7 @@ async def async_generate_text(
inputs
)
try: # TODO: update when new version of granite-io is released
message = result.next_message.model_dump()
message = result.next_message.model_dump() # pyright: ignore
except AttributeError:
message = result.results[0].next_message.model_dump()
raw_result = result.model_dump()
Expand Down