|
17 | 17 | from fastapi import FastAPI |
18 | 18 | from nc_py_api import AsyncNextcloudApp, NextcloudApp, NextcloudException |
19 | 19 | from nc_py_api.ex_app import LogLvl, persistent_storage, run_app, set_handlers |
20 | | -from nc_py_api.ex_app.providers.task_processing import TaskProcessingProvider, ShapeEnumValue |
| 20 | +from nc_py_api.ex_app.providers.task_processing import ShapeDescriptor, ShapeType, TaskProcessingProvider, \ |
| 21 | + ShapeEnumValue |
21 | 22 |
|
22 | 23 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()]) |
23 | 24 | logger = logging.getLogger(__name__) |
@@ -157,15 +158,18 @@ async def enabled_handler(enabled: bool, nc: AsyncNextcloudApp) -> str: |
157 | 158 | name="Local Large language Model: " + model, |
158 | 159 | task_type=task, |
159 | 160 | expected_runtime=30, |
160 | | - input_shape_enum_values= { |
161 | | - "tone": [ |
162 | | - ShapeEnumValue(name= "Friendlier", value= "friendlier"), |
163 | | - ShapeEnumValue(name= "More formal", value= "more formal"), |
164 | | - ShapeEnumValue(name= "Funnier", value= "funnier"), |
165 | | - ShapeEnumValue(name= "More casual", value= "more casual"), |
166 | | - ShapeEnumValue(name= "More urgent", value= "more urgent"), |
167 | | - ], |
168 | | - } if task == "core:text2text:changetone" else {} |
| 161 | + input_shape_enum_values= { |
| 162 | + "tone": [ |
| 163 | + ShapeEnumValue(name= "Friendlier", value= "friendlier"), |
| 164 | + ShapeEnumValue(name= "More formal", value= "more formal"), |
| 165 | + ShapeEnumValue(name= "Funnier", value= "funnier"), |
| 166 | + ShapeEnumValue(name= "More casual", value= "more casual"), |
| 167 | + ShapeEnumValue(name= "More urgent", value= "more urgent"), |
| 168 | + ], |
| 169 | + } if task == "core:text2text:changetone" else {}, |
| 170 | + optional_input_shape=[ |
| 171 | + ShapeDescriptor(name="memories", description="Memories to inject into the prompt", shape_type=ShapeType.LIST_OF_TEXTS) |
| 172 | + ] if task == "core:text2text:chat" else [], |
169 | 173 | ) |
170 | 174 | await nc.providers.task_processing.register(provider) |
171 | 175 | log(nc, LogLvl.INFO, f"Registered {task_processor_name}") |
|
0 commit comments