Commit 9e8f139

Merge pull request #124 from mindflowai/model-org
Improve ConfiguredModel and MindFlowModel classes to be more efficient, generic, and extensible.
2 parents 5a32999 + ff9741d commit 9e8f139

File tree

9 files changed: +82 -126 lines changed


mindflow/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.5.1"
+__version__ = "0.5.2"

mindflow/core/types/definitions/conversation.py

Lines changed: 0 additions & 8 deletions
@@ -2,19 +2,11 @@
 
 
 class ConversationParameterKey(Enum):
-    """
-    Document argument enum
-    """
-
     ID: str = "id"
     MESSAGES: str = "messages"
     TOTAL_TOKENS: str = "total_tokens"
 
 
 class ConversationID(Enum):
-    """
-    Conversation ID enum
-    """
-
     CHAT_0: str = "chat_0"
     CODE_GEN_0: str = "code_gen_0"

mindflow/core/types/definitions/document.py

Lines changed: 0 additions & 4 deletions
@@ -2,10 +2,6 @@
 
 
 class DocumentType(Enum):
-    """
-    Document type enum
-    """
-
     FILE: str = "file"
     SLACK: str = "slack"
     GITHUB: str = "github"

mindflow/core/types/definitions/mind_flow_model.py

Lines changed: 0 additions & 14 deletions
@@ -67,17 +67,3 @@ class MindFlowModelDescription(Enum):
         MindFlowModelParameterKey.DESCRIPTION.value: MindFlowModelDescription.EMBEDDING.value,
     },
 }
-
-# MindFlowModelUnion = Union[
-#     MindFlowModelID,
-#     MindFlowModelDefaults,
-#     MindFlowModelName,
-#     MindFlowModelType,
-#     MindFlowModelDescription,
-# ]
-
-
-# def get_mind_flow_model_static(
-#     static: Type[MindFlowModelUnion], key: MindFlowModelUnion
-# ) -> MindFlowModelUnion:
-#     return static.__members__[key.name]

mindflow/core/types/definitions/model.py

Lines changed: 0 additions & 13 deletions
@@ -282,16 +282,3 @@ class ModelAnthropic(Enum):
         ModelParameterKey.CONFIG_DESCRIPTION.value: ModelConfigDescription.TEXT_EMBEDDING_ADA_002.value,
     },
 }
-
-
-# ModelUnion = Union[
-#     ModelID,
-#     ModelParameterKey,
-#     ModelName,
-#     ModelHardTokenLimit,
-#     ModelDescription,
-# ]
-
-
-# def get_model_static(static: Type[ModelUnion], key: ModelUnion) -> ModelUnion:
-#     return static.__members__[key.name]

mindflow/core/types/definitions/object.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

mindflow/core/types/definitions/service.py

Lines changed: 0 additions & 16 deletions
@@ -118,19 +118,3 @@ class ServiceModel(Enum):
         ServiceParameterKey.API_URL.value: ServiceURL.PINECONE.value,
     },
 }
-
-
-# ServiceUnion = Union[
-#     ServiceID,
-#     ServiceParameterKey,
-#     ServiceConfigParameterKey,
-#     ServiceName,
-#     ServiceURL,
-#     ServiceModel,
-#     ServiceModelTypeTextEmbedding,
-#     ServiceModelTypeTextCompletion,
-# ]
-
-
-# def get_service_static(static: Type[ServiceUnion], key: ServiceUnion) -> ServiceUnion:
-#     return static.__members__[key.name]

mindflow/core/types/mindflow_model.py

Lines changed: 21 additions & 5 deletions
@@ -1,9 +1,15 @@
 import sys
-from typing import Dict
+from typing import Dict, Generic, TypeVar, cast
+from mindflow.core.types.definitions.model import ModelID
 from mindflow.core.types.store_traits.static import StaticStore
 from mindflow.core.types.store_traits.json import JsonStore
 
-from mindflow.core.types.model import ConfiguredModel
+from mindflow.core.types.model import (
+    ConfiguredModel,
+    ConfiguredOpenAIChatCompletionModel,
+    ConfiguredAnthropicChatCompletionModel,
+    ConfiguredOpenAITextEmbeddingModel,
+)
 from mindflow.core.types.service import ConfiguredServices
 from mindflow.core.types.definitions.mind_flow_model import MindFlowModelID
 from mindflow.core.types.definitions.service import (
@@ -24,11 +30,14 @@ class MindFlowModelConfig(JsonStore):
     model: str
 
 
-class ConfiguredMindFlowModel:
+T = TypeVar("T", bound="ConfiguredModel")
+
+
+class ConfiguredMindFlowModel(Generic[T]):
     id: str  # index, query, embedding
     name: str
     defaults: Dict[str, str]
-    model: ConfiguredModel
+    model: T
 
     def __init__(self, mindflow_model_id: str, configured_services: ConfiguredServices):
         self.id = mindflow_model_id
@@ -44,7 +53,14 @@ def __init__(self, mindflow_model_id: str, configured_services: ConfiguredServic
         ) is None:
             model_id = self.get_default_model_id(mindflow_model_id, configured_services)
 
-        self.model = ConfiguredModel(model_id)
+        if model_id in [ModelID.GPT_3_5_TURBO.value, ModelID.GPT_4.value]:
+            self.model = cast(T, ConfiguredOpenAIChatCompletionModel(model_id))
+        elif model_id in [ModelID.CLAUDE_INSTANT_V1.value, ModelID.CLAUDE_V1.value]:
+            self.model = cast(T, ConfiguredAnthropicChatCompletionModel(model_id))
+        elif model_id == ModelID.TEXT_EMBEDDING_ADA_002.value:
+            self.model = cast(T, ConfiguredOpenAITextEmbeddingModel(model_id))
+        else:
+            raise Exception("Unsupported model: " + model_id)
 
     def get_default_model_id(
         self, mindflow_model_id: str, configured_services: ConfiguredServices
mindflow/core/types/model.py

Lines changed: 60 additions & 28 deletions
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 import time
 from typing import Optional, Union
 
@@ -6,9 +7,6 @@
 
 import numpy as np
 from traitlets import Callable
-
-from mindflow.core.types.definitions.model_type import ModelType
-
 import tiktoken
 
 from mindflow.core.types.store_traits.json import JsonStore
@@ -40,7 +38,7 @@ class ModelConfig(JsonStore):
     soft_token_limit: int
 
 
-class ConfiguredModel(Callable):
+class ConfiguredModel(ABC, Callable):
     id: str
     name: str
     service: str
@@ -80,7 +78,28 @@ def __init__(self, model_id: str):
         except NameError:
             pass
 
-    def openai_chat_completion(
+    @abstractmethod
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class ConfiguredOpenAIChatCompletionModel(ConfiguredModel):
+    id: str
+    name: str
+    service: str
+    model_type: str
+
+    tokenizer: tiktoken.Encoding
+
+    hard_token_limit: int
+    token_cost: int
+    token_cost_unit: str
+
+    # Config
+    soft_token_limit: int
+    api_key: str
+
+    def __call__(
         self,
         messages: list,
         temperature: float = 0.0,
@@ -106,7 +125,24 @@ def openai_chat_completion(
 
         return ModelError(error_message)
 
-    def anthropic_chat_completion(
+
+class ConfiguredAnthropicChatCompletionModel(ConfiguredModel):
+    id: str
+    name: str
+    service: str
+    model_type: str
+
+    tokenizer: tiktoken.Encoding
+
+    hard_token_limit: int
+    token_cost: int
+    token_cost_unit: str
+
+    # Config
+    soft_token_limit: int
+    api_key: str
+
+    def __call__(
         self,
         prompt: str,
         temperature: float = 0.0,
@@ -131,7 +167,24 @@ def anthropic_chat_completion(
 
         return ModelError(error_message)
 
-    def openai_embedding(self, text: str) -> Union[np.ndarray, ModelError]:
+
+class ConfiguredOpenAITextEmbeddingModel(ConfiguredModel):
+    id: str
+    name: str
+    service: str
+    model_type: str
+
+    tokenizer: tiktoken.Encoding
+
+    hard_token_limit: int
+    token_cost: int
+    token_cost_unit: str
+
+    # Config
+    soft_token_limit: int
+    api_key: str
+
+    def __call__(self, text: str) -> Union[np.ndarray, ModelError]:
         try_count = 0
         error_message = ""
         while try_count < 5:
@@ -146,24 +199,3 @@ def openai_embedding(self, text: str) -> Union[np.ndarray, ModelError]:
             time.sleep(5)
 
         return ModelError(error_message)
-
-    def __call__(self, prompt, *args, **kwargs):
-        service_model_mapping = {
-            (
-                ServiceID.OPENAI.value,
-                ModelType.TEXT_COMPLETION.value,
-            ): self.openai_chat_completion,
-            (
-                ServiceID.OPENAI.value,
-                ModelType.TEXT_EMBEDDING.value,
-            ): self.openai_embedding,
-            (
-                ServiceID.ANTHROPIC.value,
-                ModelType.TEXT_COMPLETION.value,
-            ): self.anthropic_chat_completion,
-        }
-        if (
-            func := service_model_mapping.get((self.service, self.model_type))
-        ) is not None:
-            return func(prompt, *args, **kwargs)
-        raise NotImplementedError(f"Service {self.service} not implemented.")
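Note on the change above: the per-service helper methods and the runtime (service, model_type) dispatch table are replaced by an abstract __call__ on ConfiguredModel plus one subclass per provider and model type. A minimal sketch of that shape follows; the Fake* classes are hypothetical placeholders, and their bodies stand in for the real API calls and retry loops.

from abc import ABC, abstractmethod
from typing import List


class CallableModel(ABC):
    """Stand-in for ConfiguredModel: every concrete subclass must implement __call__."""

    @abstractmethod
    def __call__(self, *args, **kwargs):
        ...


class FakeChatCompletionModel(CallableModel):
    """Hypothetical analogue of ConfiguredOpenAIChatCompletionModel."""

    def __call__(self, messages: list, temperature: float = 0.0) -> str:
        # A real subclass would call its provider's chat-completion API here,
        # retrying and returning a ModelError string on repeated failure.
        return f"completion for {len(messages)} message(s) at temperature={temperature}"


class FakeTextEmbeddingModel(CallableModel):
    """Hypothetical analogue of ConfiguredOpenAITextEmbeddingModel."""

    def __call__(self, text: str) -> List[float]:
        # A real subclass would return an embedding vector (np.ndarray) or a ModelError.
        return [float(len(text))]


chat = FakeChatCompletionModel()
print(chat([{"role": "user", "content": "hi"}]))  # messages-style input
print(FakeTextEmbeddingModel()("hello"))          # plain-text input

Each subclass keeps its own call signature (messages vs. plain text), which the old single __call__(prompt, *args, **kwargs) dispatcher could not express.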
