Skip to content

Commit 51e66de

Browse files
committed
feat: all to pass directly backend or processor objects to granite-io blocks
Signed-off-by: Louis Mandel <[email protected]>
1 parent be61a56 commit 51e66de

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

src/pdl/pdl_ast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,10 @@ class GraniteioModelBlock(ModelBlock):
551551
model: ExpressionType[object]
552552
"""Model name used by the backend.
553553
"""
554-
backend: ExpressionType[str | dict[str, Any]]
554+
backend: ExpressionType[str | dict[str, Any] | object]
555555
"""Backend name and configuration.
556556
"""
557-
processor: Optional[ExpressionType[str]] = None
557+
processor: Optional[ExpressionType[str | object]] = None
558558
"""IO Processor name.
559559
"""
560560
parameters: Optional[ExpressionType[dict[str, Any]]] = None

src/pdl/pdl_granite_io.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,44 +19,46 @@
1919
class GraniteioModel:
2020
@staticmethod
2121
def processor_of_block(block: GraniteioModelBlock):
22+
from granite_io import make_backend, make_io_processor
23+
from granite_io.backend.base import Backend
24+
from granite_io.io import InputOutputProcessor
25+
26+
processor = value_of_expr(block.processor)
27+
if isinstance(processor, InputOutputProcessor):
28+
return processor
2229
model = value_of_expr(block.model)
2330
backend = value_of_expr(block.backend)
2431
assert isinstance(model, str), f"The model should be a string: {model}"
25-
assert isinstance(
26-
backend, (dict, str)
27-
), f"The backend should be a string or a dictionary: {backend}"
2832
match backend:
2933
case {"transformers": device}:
30-
assert isinstance(backend, dict)
31-
from granite_io import make_backend
32-
3334
backend = make_backend(
3435
"transformers",
3536
{
3637
"model_name": model,
3738
"device": device,
3839
},
3940
)
40-
case backend_name if isinstance(backend_name, str):
41-
from granite_io import make_backend
42-
41+
case str():
4342
backend = make_backend(
44-
backend_name,
43+
backend,
4544
{
4645
"model_name": model,
4746
},
4847
)
48+
case Backend():
49+
pass
4950
case _:
5051
assert False, f"Unexpected backend: {backend}"
51-
if block.processor is None:
52+
if processor is None:
5253
processor_name = model
5354
else:
54-
processor_name = value_of_expr(block.processor)
55+
assert isinstance(
56+
processor, str
57+
), f"The processor should be a string: {processor}"
58+
processor_name = value_of_expr(processor)
5559
assert isinstance(
5660
processor_name, str
5761
), f"The processor should be a string: {processor_name}"
58-
from granite_io import make_io_processor
59-
6062
io_processor = make_io_processor(processor_name, backend=backend)
6163
return io_processor
6264

@@ -87,7 +89,7 @@ async def async_generate_text(
8789
inputs
8890
)
8991
try: # TODO: update when new version of granite-io is released
90-
message = result.next_message.model_dump()
92+
message = result.next_message.model_dump() # pyright: ignore
9193
except AttributeError:
9294
message = result.results[0].next_message.model_dump()
9395
raw_result = result.model_dump()

0 commit comments

Comments
 (0)