55from django .db import models
66
77from api .utils .enums import (
8+ AIModelFramework ,
89 AIModelProvider ,
910 AIModelStatus ,
1011 AIModelType ,
1112 EndpointAuthType ,
1213 EndpointHTTPMethod ,
14+ HFModelClass ,
1315)
1416
1517User = get_user_model ()
@@ -36,6 +38,35 @@ class AIModel(models.Model):
3638 help_text = "Provider's model identifier (e.g., gpt-4, claude-3-opus)" ,
3739 )
3840
41+ # Huggingface Models
42+ hf_use_pipeline = models .BooleanField (default = False , help_text = "Use Pipeline inference API" )
43+ hf_auth_token = models .CharField (
44+ max_length = 255 ,
45+ blank = True ,
46+ null = True ,
47+ help_text = "Huggingface Auth Token for gated models" ,
48+ )
49+ hf_model_class = models .CharField (
50+ max_length = 100 ,
51+ choices = HFModelClass .choices ,
52+ blank = True ,
53+ null = True ,
54+ help_text = "Specify model head to use" ,
55+ )
56+ hf_attn_implementation = models .CharField (
57+ max_length = 255 ,
58+ blank = True ,
59+ default = "flash_attention_2" ,
60+ help_text = "Attention Function" ,
61+ )
62+ framework = models .CharField (
63+ max_length = 10 ,
64+ choices = AIModelFramework .choices ,
65+ blank = True ,
66+ null = True ,
67+ help_text = "Framework (PyTorch or TensorFlow)" ,
68+ )
69+
3970 # Ownership & Organization
4071 organization = models .ForeignKey (
4172 "api.Organization" ,
@@ -65,17 +96,13 @@ class AIModel(models.Model):
6596 )
6697
6798 # Input/Output Schema
68- input_schema = models .JSONField (
69- default = dict , help_text = "Expected input format and parameters"
70- )
99+ input_schema = models .JSONField (default = dict , help_text = "Expected input format and parameters" )
71100 output_schema = models .JSONField (default = dict , help_text = "Expected output format" )
72101
73102 # Metadata
74103 tags = models .ManyToManyField ("api.Tag" , blank = True )
75104 sectors = models .ManyToManyField ("api.Sector" , blank = True , related_name = "ai_models" )
76- geographies = models .ManyToManyField (
77- "api.Geography" , blank = True , related_name = "ai_models"
78- )
105+ geographies = models .ManyToManyField ("api.Geography" , blank = True , related_name = "ai_models" )
79106 metadata = models .JSONField (
80107 default = dict ,
81108 help_text = "Additional metadata (training data info, limitations, etc.)" ,
@@ -151,14 +178,10 @@ class ModelEndpoint(models.Model):
151178 Supports multiple endpoints per model (e.g., different regions, fallbacks)
152179 """
153180
154- model = models .ForeignKey (
155- AIModel , on_delete = models .CASCADE , related_name = "endpoints"
156- )
181+ model = models .ForeignKey (AIModel , on_delete = models .CASCADE , related_name = "endpoints" )
157182
158183 # Endpoint Configuration
159- url = models .URLField (
160- max_length = 500 , validators = [URLValidator ()], help_text = "API endpoint URL"
161- )
184+ url = models .URLField (max_length = 500 , validators = [URLValidator ()], help_text = "API endpoint URL" )
162185 http_method = models .CharField (
163186 max_length = 10 ,
164187 choices = EndpointHTTPMethod .choices ,
@@ -176,9 +199,7 @@ class ModelEndpoint(models.Model):
176199 )
177200
178201 # Request Configuration
179- headers = models .JSONField (
180- default = dict , help_text = "Additional headers to include in requests"
181- )
202+ headers = models .JSONField (default = dict , help_text = "Additional headers to include in requests" )
182203 request_template = models .JSONField (
183204 default = dict , help_text = "Template for request body with placeholders"
184205 )
@@ -222,19 +243,15 @@ def success_rate(self):
222243 """Calculate success rate"""
223244 if self .total_requests == 0 :
224245 return None
225- return (
226- (self .total_requests - self .failed_requests ) / self .total_requests
227- ) * 100
246+ return ((self .total_requests - self .failed_requests ) / self .total_requests ) * 100
228247
229248
230249class ModelAPIKey (models .Model ):
231250 """
232251 Encrypted storage for API keys/credentials for model endpoints.
233252 """
234253
235- model = models .ForeignKey (
236- AIModel , on_delete = models .CASCADE , related_name = "api_keys"
237- )
254+ model = models .ForeignKey (AIModel , on_delete = models .CASCADE , related_name = "api_keys" )
238255
239256 name = models .CharField (max_length = 100 , help_text = "Friendly name for this API key" )
240257
0 commit comments