|
| 1 | +"""AI Model Version model for version-specific configuration.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING |
| 6 | + |
| 7 | +from django.db import models |
| 8 | + |
| 9 | +from api.utils.enums import AIModelStatus |
| 10 | + |
| 11 | +if TYPE_CHECKING: |
| 12 | + from django.db.models import QuerySet |
| 13 | + |
| 14 | + |
| 15 | +class AIModelVersion(models.Model): |
| 16 | + """ |
| 17 | + Version of an AI Model with its own configuration. |
| 18 | + Each version can have multiple providers. |
| 19 | + """ |
| 20 | + |
| 21 | + ai_model = models.ForeignKey( |
| 22 | + "api.AIModel", |
| 23 | + on_delete=models.CASCADE, |
| 24 | + related_name="versions", |
| 25 | + ) |
| 26 | + version = models.CharField(max_length=50, help_text="Version number (e.g., 1.0.0)") |
| 27 | + version_notes = models.TextField(blank=True, help_text="Changelog/notes for this version") |
| 28 | + |
| 29 | + # Version-specific capabilities |
| 30 | + supports_streaming = models.BooleanField(default=False) |
| 31 | + max_tokens = models.IntegerField(null=True, blank=True, help_text="Maximum tokens supported") |
| 32 | + supported_languages = models.JSONField( |
| 33 | + default=list, help_text="List of supported language codes" |
| 34 | + ) |
| 35 | + input_schema = models.JSONField(default=dict, help_text="Expected input format and parameters") |
| 36 | + output_schema = models.JSONField(default=dict, help_text="Expected output format") |
| 37 | + metadata = models.JSONField(default=dict, help_text="Additional version-specific metadata") |
| 38 | + |
| 39 | + # Status |
| 40 | + status = models.CharField( |
| 41 | + max_length=20, |
| 42 | + choices=AIModelStatus.choices, |
| 43 | + default=AIModelStatus.REGISTERED, |
| 44 | + ) |
| 45 | + is_latest = models.BooleanField(default=False, help_text="Whether this is the latest version") |
| 46 | + |
| 47 | + # Timestamps |
| 48 | + created_at = models.DateTimeField(auto_now_add=True) |
| 49 | + updated_at = models.DateTimeField(auto_now=True) |
| 50 | + published_at = models.DateTimeField(null=True, blank=True) |
| 51 | + |
| 52 | + class Meta: |
| 53 | + unique_together = ["ai_model", "version"] |
| 54 | + ordering = ["-created_at"] |
| 55 | + verbose_name = "AI Model Version" |
| 56 | + verbose_name_plural = "AI Model Versions" |
| 57 | + |
| 58 | + def __str__(self): |
| 59 | + return f"{self.ai_model.name} v{self.version}" |
| 60 | + |
| 61 | + def save(self, *args, **kwargs): |
| 62 | + # If this is set as latest, unset others |
| 63 | + if self.is_latest: |
| 64 | + AIModelVersion.objects.filter(ai_model=self.ai_model, is_latest=True).exclude( |
| 65 | + pk=self.pk |
| 66 | + ).update(is_latest=False) |
| 67 | + super().save(*args, **kwargs) |
| 68 | + |
| 69 | + def copy_providers_from(self, source_version: AIModelVersion) -> None: |
| 70 | + """ |
| 71 | + Copy all providers from another version. |
| 72 | + Used when creating a new version. |
| 73 | + """ |
| 74 | + for provider in source_version.providers.all(): # type: ignore[attr-defined] |
| 75 | + # Create a copy of the provider |
| 76 | + VersionProvider.objects.create( |
| 77 | + version=self, |
| 78 | + provider=provider.provider, # type: ignore[attr-defined] |
| 79 | + provider_model_id=provider.provider_model_id, # type: ignore[attr-defined] |
| 80 | + is_primary=provider.is_primary, # type: ignore[attr-defined] |
| 81 | + is_active=provider.is_active, # type: ignore[attr-defined] |
| 82 | + hf_use_pipeline=provider.hf_use_pipeline, # type: ignore[attr-defined] |
| 83 | + hf_auth_token=provider.hf_auth_token, # type: ignore[attr-defined] |
| 84 | + hf_model_class=provider.hf_model_class, # type: ignore[attr-defined] |
| 85 | + hf_attn_implementation=provider.hf_attn_implementation, # type: ignore[attr-defined] |
| 86 | + framework=provider.framework, # type: ignore[attr-defined] |
| 87 | + config=provider.config, # type: ignore[attr-defined] |
| 88 | + ) |
| 89 | + |
| 90 | + |
| 91 | +class VersionProvider(models.Model): |
| 92 | + """ |
| 93 | + Provider configuration for a specific version. |
| 94 | + A version can have multiple providers (HF, Custom, OpenAI, etc.) |
| 95 | + Only ONE can be primary per version. |
| 96 | + """ |
| 97 | + |
| 98 | + from api.utils.enums import AIModelFramework, AIModelProvider, HFModelClass |
| 99 | + |
| 100 | + version = models.ForeignKey( |
| 101 | + AIModelVersion, |
| 102 | + on_delete=models.CASCADE, |
| 103 | + related_name="providers", |
| 104 | + ) |
| 105 | + |
| 106 | + # Provider info |
| 107 | + provider = models.CharField(max_length=50, choices=AIModelProvider.choices) |
| 108 | + provider_model_id = models.CharField( |
| 109 | + max_length=255, |
| 110 | + blank=True, |
| 111 | + help_text="Provider's model identifier (e.g., gpt-4, claude-3-opus)", |
| 112 | + ) |
| 113 | + is_primary = models.BooleanField( |
| 114 | + default=False, help_text="Whether this is the primary provider for the version" |
| 115 | + ) |
| 116 | + is_active = models.BooleanField(default=True) |
| 117 | + |
| 118 | + # Huggingface-specific fields |
| 119 | + hf_use_pipeline = models.BooleanField(default=False, help_text="Use Pipeline inference API") |
| 120 | + hf_auth_token = models.CharField( |
| 121 | + max_length=255, |
| 122 | + blank=True, |
| 123 | + null=True, |
| 124 | + help_text="Huggingface Auth Token for gated models", |
| 125 | + ) |
| 126 | + hf_model_class = models.CharField( |
| 127 | + max_length=100, |
| 128 | + choices=HFModelClass.choices, |
| 129 | + blank=True, |
| 130 | + null=True, |
| 131 | + help_text="Specify model head to use", |
| 132 | + ) |
| 133 | + hf_attn_implementation = models.CharField( |
| 134 | + max_length=255, |
| 135 | + blank=True, |
| 136 | + default="flash_attention_2", |
| 137 | + help_text="Attention Function", |
| 138 | + ) |
| 139 | + framework = models.CharField( |
| 140 | + max_length=10, |
| 141 | + choices=AIModelFramework.choices, |
| 142 | + blank=True, |
| 143 | + null=True, |
| 144 | + help_text="Framework (PyTorch or TensorFlow)", |
| 145 | + ) |
| 146 | + |
| 147 | + # Provider-specific configuration |
| 148 | + config = models.JSONField(default=dict, help_text="Provider-specific configuration") |
| 149 | + |
| 150 | + # Timestamps |
| 151 | + created_at = models.DateTimeField(auto_now_add=True) |
| 152 | + updated_at = models.DateTimeField(auto_now=True) |
| 153 | + |
| 154 | + class Meta: |
| 155 | + ordering = ["-is_primary", "-created_at"] |
| 156 | + verbose_name = "Version Provider" |
| 157 | + verbose_name_plural = "Version Providers" |
| 158 | + |
| 159 | + def __str__(self): |
| 160 | + primary_str = " (Primary)" if self.is_primary else "" |
| 161 | + return f"{self.version} - {self.provider}{primary_str}" |
| 162 | + |
| 163 | + def save(self, *args, **kwargs): |
| 164 | + # Ensure only one primary per version |
| 165 | + if self.is_primary: |
| 166 | + VersionProvider.objects.filter(version=self.version, is_primary=True).exclude( |
| 167 | + pk=self.pk |
| 168 | + ).update(is_primary=False) |
| 169 | + super().save(*args, **kwargs) |
0 commit comments