|
19 | 19 | class GraniteioModel: |
20 | 20 | @staticmethod |
21 | 21 | def processor_of_block(block: GraniteioModelBlock): |
| 22 | + from granite_io import make_backend, make_io_processor |
| 23 | + from granite_io.backend.base import Backend |
| 24 | + from granite_io.io import InputOutputProcessor |
| 25 | + |
| 26 | + processor = value_of_expr(block.processor) |
| 27 | + if isinstance(processor, InputOutputProcessor): |
| 28 | + return processor |
22 | 29 | model = value_of_expr(block.model) |
23 | 30 | backend = value_of_expr(block.backend) |
24 | 31 | assert isinstance(model, str), f"The model should be a string: {model}" |
25 | | - assert isinstance( |
26 | | - backend, (dict, str) |
27 | | - ), f"The backend should be a string or a dictionary: {backend}" |
28 | 32 | match backend: |
29 | 33 | case {"transformers": device}: |
30 | | - assert isinstance(backend, dict) |
31 | | - from granite_io import make_backend |
32 | | - |
33 | 34 | backend = make_backend( |
34 | 35 | "transformers", |
35 | 36 | { |
36 | 37 | "model_name": model, |
37 | 38 | "device": device, |
38 | 39 | }, |
39 | 40 | ) |
40 | | - case backend_name if isinstance(backend_name, str): |
41 | | - from granite_io import make_backend |
42 | | - |
| 41 | + case str(): |
43 | 42 | backend = make_backend( |
44 | | - backend_name, |
| 43 | + backend, |
45 | 44 | { |
46 | 45 | "model_name": model, |
47 | 46 | }, |
48 | 47 | ) |
| 48 | + case Backend(): |
| 49 | + pass |
49 | 50 | case _: |
50 | 51 | assert False, f"Unexpected backend: {backend}" |
51 | | - if block.processor is None: |
| 52 | + if processor is None: |
52 | 53 | processor_name = model |
53 | 54 | else: |
54 | | - processor_name = value_of_expr(block.processor) |
| 55 | + assert isinstance( |
| 56 | + processor, str |
| 57 | + ), f"The processor should be a string: {processor}" |
| 58 | + processor_name = value_of_expr(processor) |
55 | 59 | assert isinstance( |
56 | 60 | processor_name, str |
57 | 61 | ), f"The processor should be a string: {processor_name}" |
58 | | - from granite_io import make_io_processor |
59 | | - |
60 | 62 | io_processor = make_io_processor(processor_name, backend=backend) |
61 | 63 | return io_processor |
62 | 64 |
|
@@ -87,7 +89,7 @@ async def async_generate_text( |
87 | 89 | inputs |
88 | 90 | ) |
89 | 91 | try: # TODO: update when new version of granite-io is released |
90 | | - message = result.next_message.model_dump() |
| 92 | + message = result.next_message.model_dump() # pyright: ignore |
91 | 93 | except AttributeError: |
92 | 94 | message = result.results[0].next_message.model_dump() |
93 | 95 | raw_result = result.model_dump() |
|
0 commit comments