1+ from abc import ABC , abstractmethod
12import time
23from typing import Optional , Union
34
67
78import numpy as np
89from traitlets import Callable
9-
10- from mindflow .core .types .definitions .model_type import ModelType
11-
1210import tiktoken
1311
1412from mindflow .core .types .store_traits .json import JsonStore
@@ -40,7 +38,7 @@ class ModelConfig(JsonStore):
4038 soft_token_limit : int
4139
4240
43- class ConfiguredModel (Callable ):
41+ class ConfiguredModel (ABC , Callable ):
4442 id : str
4543 name : str
4644 service : str
@@ -80,7 +78,28 @@ def __init__(self, model_id: str):
8078 except NameError :
8179 pass
8280
83- def openai_chat_completion (
81+ @abstractmethod
82+ def __call__ (self , * args , ** kwargs ):
83+ pass
84+
85+
86+ class ConfiguredOpenAIChatCompletionModel (ConfiguredModel ):
87+ id : str
88+ name : str
89+ service : str
90+ model_type : str
91+
92+ tokenizer : tiktoken .Encoding
93+
94+ hard_token_limit : int
95+ token_cost : int
96+ token_cost_unit : str
97+
98+ # Config
99+ soft_token_limit : int
100+ api_key : str
101+
102+ def __call__ (
84103 self ,
85104 messages : list ,
86105 temperature : float = 0.0 ,
@@ -106,7 +125,24 @@ def openai_chat_completion(
106125
107126 return ModelError (error_message )
108127
109- def anthropic_chat_completion (
128+
129+ class ConfiguredAnthropicChatCompletionModel (ConfiguredModel ):
130+ id : str
131+ name : str
132+ service : str
133+ model_type : str
134+
135+ tokenizer : tiktoken .Encoding
136+
137+ hard_token_limit : int
138+ token_cost : int
139+ token_cost_unit : str
140+
141+ # Config
142+ soft_token_limit : int
143+ api_key : str
144+
145+ def __call__ (
110146 self ,
111147 prompt : str ,
112148 temperature : float = 0.0 ,
@@ -131,7 +167,24 @@ def anthropic_chat_completion(
131167
132168 return ModelError (error_message )
133169
134- def openai_embedding (self , text : str ) -> Union [np .ndarray , ModelError ]:
170+
171+ class ConfiguredOpenAITextEmbeddingModel (ConfiguredModel ):
172+ id : str
173+ name : str
174+ service : str
175+ model_type : str
176+
177+ tokenizer : tiktoken .Encoding
178+
179+ hard_token_limit : int
180+ token_cost : int
181+ token_cost_unit : str
182+
183+ # Config
184+ soft_token_limit : int
185+ api_key : str
186+
187+ def __call__ (self , text : str ) -> Union [np .ndarray , ModelError ]:
135188 try_count = 0
136189 error_message = ""
137190 while try_count < 5 :
@@ -146,24 +199,3 @@ def openai_embedding(self, text: str) -> Union[np.ndarray, ModelError]:
146199 time .sleep (5 )
147200
148201 return ModelError (error_message )
149-
150- def __call__ (self , prompt , * args , ** kwargs ):
151- service_model_mapping = {
152- (
153- ServiceID .OPENAI .value ,
154- ModelType .TEXT_COMPLETION .value ,
155- ): self .openai_chat_completion ,
156- (
157- ServiceID .OPENAI .value ,
158- ModelType .TEXT_EMBEDDING .value ,
159- ): self .openai_embedding ,
160- (
161- ServiceID .ANTHROPIC .value ,
162- ModelType .TEXT_COMPLETION .value ,
163- ): self .anthropic_chat_completion ,
164- }
165- if (
166- func := service_model_mapping .get ((self .service , self .model_type ))
167- ) is not None :
168- return func (prompt , * args , ** kwargs )
169- raise NotImplementedError (f"Service { self .service } not implemented." )
0 commit comments