33import contextlib
44import dataclasses
55import typing
6+ from enum import IntEnum
7+
8+ from pydantic import RootModel
9+ from pydantic .dataclasses import dataclass
610
711from ..._exceptions import NextcloudException , NextcloudExceptionNotFound
8- from ..._misc import clear_from_params_empty , require_capabilities
12+ from ..._misc import require_capabilities
913from ..._session import AsyncNcSessionApp , NcSessionApp
1014
1115_EP_SUFFIX : str = "ai_provider/task_processing"
1216
1317
14- @dataclasses .dataclass
15- class TaskProcessingProvider :
16- """TaskProcessing provider description."""
18+ class ShapeType (IntEnum ):
19+ """Enum for shape types."""
20+
21+ NUMBER = 0
22+ TEXT = 1
23+ IMAGE = 2
24+ AUDIO = 3
25+ VIDEO = 4
26+ FILE = 5
27+ ENUM = 6
28+ LIST_OF_NUMBERS = 10
29+ LIST_OF_TEXTS = 11
30+ LIST_OF_IMAGES = 12
31+ LIST_OF_AUDIOS = 13
32+ LIST_OF_VIDEOS = 14
33+ LIST_OF_FILES = 15
34+
35+
36+ @dataclass
37+ class ShapeEnumValue :
38+ """Data object for input output shape enum slot value."""
39+
40+ name : str
41+ """Name of the enum slot value which will be displayed in the UI"""
42+ value : str
43+ """Value of the enum slot value"""
44+
1745
18- def __init__ (self , raw_data : dict ):
19- self ._raw_data = raw_data
46+ @dataclass
47+ class ShapeDescriptor :
48+ """Data object for input output shape entries."""
2049
21- @property
22- def name (self ) -> str :
23- """Unique ID for the provider."""
24- return self ._raw_data ["name" ]
50+ name : str
51+ """Name of the shape entry"""
52+ description : str
53+ """Description of the shape entry"""
54+ shape_type : ShapeType
55+ """Type of the shape entry"""
2556
26- @property
27- def display_name (self ) -> str :
28- """Providers display name."""
29- return self ._raw_data ["display_name" ]
3057
31- @property
32- def task_type (self ) -> str :
33- """The TaskType provided by this provider."""
34- return self ._raw_data ["task_type" ]
58+ @dataclass
59+ class TaskType :
60+ """TaskType description for the provider."""
61+
62+ id : str
63+ """The unique ID for the task type."""
64+ name : str
65+ """The localized name of the task type."""
66+ description : str
67+ """The localized description of the task type."""
68+ input_shape : list [ShapeDescriptor ]
69+ """The input shape of the task."""
70+ output_shape : list [ShapeDescriptor ]
71+ """The output shape of the task."""
72+
73+
74+ @dataclass
75+ class TaskProcessingProvider :
76+
77+ id : str
78+ """Unique ID for the provider."""
79+ name : str
80+ """The localized name of this provider"""
81+ task_type : str
82+ """The TaskType provided by this provider."""
83+ expected_runtime : int = dataclasses .field (default = 0 )
84+ """Expected runtime of the task in seconds."""
85+ optional_input_shape : list [ShapeDescriptor ] = dataclasses .field (default_factory = list )
86+ """Optional input shape of the task."""
87+ optional_output_shape : list [ShapeDescriptor ] = dataclasses .field (default_factory = list )
88+ """Optional output shape of the task."""
89+ input_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
90+ """The option dict for each input shape ENUM slot."""
91+ input_shape_defaults : dict [str , str | int | float ] = dataclasses .field (default_factory = dict )
92+ """The default values for input shape slots."""
93+ optional_input_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
94+ """The option list for each optional input shape ENUM slot."""
95+ optional_input_shape_defaults : dict [str , str | int | float ] = dataclasses .field (default_factory = dict )
96+ """The default values for optional input shape slots."""
97+ output_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
98+ """The option list for each output shape ENUM slot."""
99+ optional_output_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
100+ """The option list for each optional output shape ENUM slot."""
35101
36102 def __repr__ (self ):
37103 return f"<{ self .__class__ .__name__ } name={ self .name } , type={ self .task_type } >"
@@ -44,17 +110,16 @@ def __init__(self, session: NcSessionApp):
44110 self ._session = session
45111
46112 def register (
47- self , name : str , display_name : str , task_type : str , custom_task_type : dict [str , typing .Any ] | None = None
113+ self ,
114+ provider : TaskProcessingProvider ,
115+ custom_task_type : TaskType | None = None ,
48116 ) -> None :
49117 """Registers or edit the TaskProcessing provider."""
50118 require_capabilities ("app_api" , self ._session .capabilities )
51119 params = {
52- "name" : name ,
53- "displayName" : display_name ,
54- "taskType" : task_type ,
55- "customTaskType" : custom_task_type ,
120+ "provider" : RootModel (provider ).model_dump (),
121+ ** ({"customTaskType" : RootModel (custom_task_type ).model_dump ()} if custom_task_type else {}),
56122 }
57- clear_from_params_empty (["customTaskType" ], params )
58123 self ._session .ocs ("POST" , f"{ self ._session .ae_url } /{ _EP_SUFFIX } " , json = params )
59124
60125 def unregister (self , name : str , not_fail = True ) -> None :
@@ -123,17 +188,16 @@ def __init__(self, session: AsyncNcSessionApp):
123188 self ._session = session
124189
125190 async def register (
126- self , name : str , display_name : str , task_type : str , custom_task_type : dict [str , typing .Any ] | None = None
191+ self ,
192+ provider : TaskProcessingProvider ,
193+ custom_task_type : TaskType | None = None ,
127194 ) -> None :
128195 """Registers or edit the TaskProcessing provider."""
129196 require_capabilities ("app_api" , await self ._session .capabilities )
130197 params = {
131- "name" : name ,
132- "displayName" : display_name ,
133- "taskType" : task_type ,
134- "customTaskType" : custom_task_type ,
198+ "provider" : RootModel (provider ).model_dump (),
199+ ** ({"customTaskType" : RootModel (custom_task_type ).model_dump ()} if custom_task_type else {}),
135200 }
136- clear_from_params_empty (["customTaskType" ], params )
137201 await self ._session .ocs ("POST" , f"{ self ._session .ae_url } /{ _EP_SUFFIX } " , json = params )
138202
139203 async def unregister (self , name : str , not_fail = True ) -> None :
0 commit comments