Skip to content

Commit 32c5be6

Browse files
committed
add updated api model mutations
1 parent fa76e4c commit 32c5be6

File tree

5 files changed

+723
-3
lines changed

5 files changed

+723
-3
lines changed

api/models/AIModelVersion.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
"""AI Model Version model for version-specific configuration."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
from django.db import models
8+
9+
from api.utils.enums import AIModelStatus
10+
11+
if TYPE_CHECKING:
12+
from django.db.models import QuerySet
13+
14+
15+
class AIModelVersion(models.Model):
16+
"""
17+
Version of an AI Model with its own configuration.
18+
Each version can have multiple providers.
19+
"""
20+
21+
ai_model = models.ForeignKey(
22+
"api.AIModel",
23+
on_delete=models.CASCADE,
24+
related_name="versions",
25+
)
26+
version = models.CharField(max_length=50, help_text="Version number (e.g., 1.0.0)")
27+
version_notes = models.TextField(blank=True, help_text="Changelog/notes for this version")
28+
29+
# Version-specific capabilities
30+
supports_streaming = models.BooleanField(default=False)
31+
max_tokens = models.IntegerField(null=True, blank=True, help_text="Maximum tokens supported")
32+
supported_languages = models.JSONField(
33+
default=list, help_text="List of supported language codes"
34+
)
35+
input_schema = models.JSONField(default=dict, help_text="Expected input format and parameters")
36+
output_schema = models.JSONField(default=dict, help_text="Expected output format")
37+
metadata = models.JSONField(default=dict, help_text="Additional version-specific metadata")
38+
39+
# Status
40+
status = models.CharField(
41+
max_length=20,
42+
choices=AIModelStatus.choices,
43+
default=AIModelStatus.REGISTERED,
44+
)
45+
is_latest = models.BooleanField(default=False, help_text="Whether this is the latest version")
46+
47+
# Timestamps
48+
created_at = models.DateTimeField(auto_now_add=True)
49+
updated_at = models.DateTimeField(auto_now=True)
50+
published_at = models.DateTimeField(null=True, blank=True)
51+
52+
class Meta:
53+
unique_together = ["ai_model", "version"]
54+
ordering = ["-created_at"]
55+
verbose_name = "AI Model Version"
56+
verbose_name_plural = "AI Model Versions"
57+
58+
def __str__(self):
59+
return f"{self.ai_model.name} v{self.version}"
60+
61+
def save(self, *args, **kwargs):
62+
# If this is set as latest, unset others
63+
if self.is_latest:
64+
AIModelVersion.objects.filter(ai_model=self.ai_model, is_latest=True).exclude(
65+
pk=self.pk
66+
).update(is_latest=False)
67+
super().save(*args, **kwargs)
68+
69+
def copy_providers_from(self, source_version: AIModelVersion) -> None:
70+
"""
71+
Copy all providers from another version.
72+
Used when creating a new version.
73+
"""
74+
for provider in source_version.providers.all(): # type: ignore[attr-defined]
75+
# Create a copy of the provider
76+
VersionProvider.objects.create(
77+
version=self,
78+
provider=provider.provider, # type: ignore[attr-defined]
79+
provider_model_id=provider.provider_model_id, # type: ignore[attr-defined]
80+
is_primary=provider.is_primary, # type: ignore[attr-defined]
81+
is_active=provider.is_active, # type: ignore[attr-defined]
82+
hf_use_pipeline=provider.hf_use_pipeline, # type: ignore[attr-defined]
83+
hf_auth_token=provider.hf_auth_token, # type: ignore[attr-defined]
84+
hf_model_class=provider.hf_model_class, # type: ignore[attr-defined]
85+
hf_attn_implementation=provider.hf_attn_implementation, # type: ignore[attr-defined]
86+
framework=provider.framework, # type: ignore[attr-defined]
87+
config=provider.config, # type: ignore[attr-defined]
88+
)
89+
90+
91+
class VersionProvider(models.Model):
92+
"""
93+
Provider configuration for a specific version.
94+
A version can have multiple providers (HF, Custom, OpenAI, etc.)
95+
Only ONE can be primary per version.
96+
"""
97+
98+
from api.utils.enums import AIModelFramework, AIModelProvider, HFModelClass
99+
100+
version = models.ForeignKey(
101+
AIModelVersion,
102+
on_delete=models.CASCADE,
103+
related_name="providers",
104+
)
105+
106+
# Provider info
107+
provider = models.CharField(max_length=50, choices=AIModelProvider.choices)
108+
provider_model_id = models.CharField(
109+
max_length=255,
110+
blank=True,
111+
help_text="Provider's model identifier (e.g., gpt-4, claude-3-opus)",
112+
)
113+
is_primary = models.BooleanField(
114+
default=False, help_text="Whether this is the primary provider for the version"
115+
)
116+
is_active = models.BooleanField(default=True)
117+
118+
# Huggingface-specific fields
119+
hf_use_pipeline = models.BooleanField(default=False, help_text="Use Pipeline inference API")
120+
hf_auth_token = models.CharField(
121+
max_length=255,
122+
blank=True,
123+
null=True,
124+
help_text="Huggingface Auth Token for gated models",
125+
)
126+
hf_model_class = models.CharField(
127+
max_length=100,
128+
choices=HFModelClass.choices,
129+
blank=True,
130+
null=True,
131+
help_text="Specify model head to use",
132+
)
133+
hf_attn_implementation = models.CharField(
134+
max_length=255,
135+
blank=True,
136+
default="flash_attention_2",
137+
help_text="Attention Function",
138+
)
139+
framework = models.CharField(
140+
max_length=10,
141+
choices=AIModelFramework.choices,
142+
blank=True,
143+
null=True,
144+
help_text="Framework (PyTorch or TensorFlow)",
145+
)
146+
147+
# Provider-specific configuration
148+
config = models.JSONField(default=dict, help_text="Provider-specific configuration")
149+
150+
# Timestamps
151+
created_at = models.DateTimeField(auto_now_add=True)
152+
updated_at = models.DateTimeField(auto_now=True)
153+
154+
class Meta:
155+
ordering = ["-is_primary", "-created_at"]
156+
verbose_name = "Version Provider"
157+
verbose_name_plural = "Version Providers"
158+
159+
def __str__(self):
160+
primary_str = " (Primary)" if self.is_primary else ""
161+
return f"{self.version} - {self.provider}{primary_str}"
162+
163+
def save(self, *args, **kwargs):
164+
# Ensure only one primary per version
165+
if self.is_primary:
166+
VersionProvider.objects.filter(version=self.version, is_primary=True).exclude(
167+
pk=self.pk
168+
).update(is_primary=False)
169+
super().save(*args, **kwargs)

api/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from api.models.AccessModel import AccessModel, AccessModelResource
22
from api.models.AIModel import AIModel, ModelAPIKey, ModelEndpoint
3+
from api.models.AIModelVersion import AIModelVersion, VersionProvider
34
from api.models.Catalog import Catalog
45
from api.models.Collaborative import Collaborative
56
from api.models.CollaborativeMetadata import CollaborativeMetadata

0 commit comments

Comments
 (0)