@@ -1,13 +1,14 @@
 from datetime import datetime
 import json
 import os
+import shutil
 import traceback
 import uuid
 
 from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
 from fastapi import Request
 from starlette.responses import JSONResponse
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, field_validator, model_validator
 
 from ..utilities.finetuning.CausalLLMTuner import CausalLLMFinetuner
 from ..utilities.finetuning.QuestionAnsweringTuner import QuestionAnsweringTuner
@@ -20,13 +21,17 @@
     prefix="/finetune",
 )
 
+# Valid task types
+VALID_TASKS = ["text-generation", "summarization", "extractive-question-answering"]
+VALID_TASKS_STR = "'text-generation', 'summarization', or 'extractive-question-answering'"
+
 ## Pydantic Data Validator Classes
 class TaskFormData(BaseModel):
     task: str
     @field_validator("task")
     def validate_task(cls, task):
-        if task not in ["text-generation", "summarization", "extractive-question-answering"]:
-            raise ValueError("Invalid task. Must be one of 'text-generation', 'summarization', or 'extractive-question-answering'.")
+        if task not in VALID_TASKS:
+            raise ValueError(f"Invalid task. Must be one of {VALID_TASKS_STR}.")
        return task
 
 class SelectedModelFormData(BaseModel):
@@ -83,8 +88,8 @@ def validate_dataset_prescence(cls, dataset):
         return dataset
     @field_validator("task")
     def validate_task(cls, task):
-        if task not in ["text-generation", "summarization", "question-answering"]:
-            raise ValueError("Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'.")
+        if task not in VALID_TASKS:
+            raise ValueError(f"Invalid task. Must be one of {VALID_TASKS_STR}.")
         return task
     @field_validator("model_name")
     def validate_model_name(cls, model_name):
@@ -95,7 +100,7 @@ def validate_model_name(cls, model_name):
     def validate_num_train_epochs(cls, num_train_epochs):
         if num_train_epochs <= 0:
             raise ValueError("Number of training epochs must be greater than 0.")
-        elif num_train_epochs > 30:
+        if num_train_epochs >= 50:
             raise ValueError("Number of training epochs must be less than 50.")
         return num_train_epochs
     @field_validator("compute_specs")
@@ -154,13 +159,9 @@ def validate_bf16(cls, bf16):
             raise ValueError("bf16 must be true or false.")
         return bf16
     @field_validator("per_device_train_batch_size")
-    def validate_per_device_train_batch_size(cls, per_device_train_batch_size, compute_specs):
+    def validate_per_device_train_batch_size(cls, per_device_train_batch_size):
         if per_device_train_batch_size <= 0:
             raise ValueError("Batch size must be greater than 0.")
-        elif per_device_train_batch_size > 3 and compute_specs != "high_end":
-            raise ValueError("Batch size must be less than 4. Your device cannot support a higher batch size.")
-        elif per_device_train_batch_size > 8 and compute_specs == "high_end":
-            raise ValueError("Batch size must be less than 9. Higher batch sizes cause out of memory error.")
         return per_device_train_batch_size
     @field_validator("per_device_eval_batch_size")
     def validate_per_device_eval_batch_size(cls, per_device_eval_batch_size):
@@ -236,11 +237,20 @@ def validate_dataset(cls, dataset):
         if not dataset:
             raise ValueError("Dataset cannot be empty.")
         return dataset
+
+    @model_validator(mode='after')
+    def validate_batch_size_with_compute_specs(self):
+        """Validate batch size based on compute specs"""
+        if self.per_device_train_batch_size > 3 and self.compute_specs != "high_end":
+            raise ValueError("Batch size must be 3 or less. Your device cannot support a higher batch size.")
+        elif self.per_device_train_batch_size > 8 and self.compute_specs == "high_end":
+            raise ValueError("Batch size must be 8 or less. Higher batch sizes cause out of memory error.")
+        return self
 
 
 @router.get("/detect")
 async def detect_hardware_page(request: Request) -> JSONResponse:
-    global_manager.clear_global_manager.settings_cache()  # Clear the cache to ensure fresh detection
+    global_manager.clear_settings_cache()  # Clear the cache to ensure fresh detection
     return JSONResponse({
         "app_name": global_manager.app_name,
         "message": "Ready to detect hardware"
@@ -290,7 +300,7 @@ async def detect_hardware(request: Request) -> JSONResponse:
         print(e)
         raise HTTPException(
             status_code=400,
-            detail="Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'."
+            detail=f"Invalid task. Must be one of {VALID_TASKS_STR}."
         )
     except Exception as e:
         print("General exception triggered")
@@ -317,6 +327,14 @@ async def set_model(request: Request) -> None:
 
 @router.post("/validate_custom_model", response_class=JSONResponse)
 async def validate_custom_model(request: Request) -> JSONResponse:
+    """
+    Validate a custom model from HuggingFace Hub.
+
+    Note: Currently validates repository existence but not task compatibility.
+    Consider adding architecture-task compatibility checks (e.g., ensure
+    summarization models aren't used for text generation tasks) for better
+    user experience and to prevent fine-tuning failures.
+    """
     try:
         form = await request.json()
         validation_data = CustomModelValidationData(repo_name=form["repo_name"])
@@ -424,15 +442,14 @@ async def load_settings(json_file: UploadFile = File(...), settings: str = Form(
 
 
 def finetuning_task(llm_tuner) -> None:
+    output_dir = None
     try:
         llm_tuner.load_dataset(global_manager.settings_builder.dataset)
+        output_dir = llm_tuner.output_dir  # Store for cleanup on failure
         path = llm_tuner.finetune()
 
-        # Handle both absolute and relative paths
-        if os.path.isabs(path):
-            model_path = path
-        else:
-            model_path = os.path.join(os.path.dirname(__file__), path.replace("./", ""))
+        # Use the path returned from finetune (should be absolute)
+        model_path = os.path.abspath(path) if not os.path.isabs(path) else path
 
         model_data = {
             "model_name": global_manager.settings_builder.fine_tuned_name.split('/')[-1] if global_manager.settings_builder.fine_tuned_name else os.path.basename(model_path),
@@ -445,6 +462,17 @@ def finetuning_task(llm_tuner) -> None:
             "is_custom_base_model": global_manager.settings_builder.is_custom_model
         }
         global_manager.db_manager.add_model(model_data)
+
+    except Exception as e:
+        print(f"Fine-tuning failed: {e}")
+        # Clean up failed fine-tuning artifacts
+        if output_dir and os.path.exists(output_dir):
+            try:
+                shutil.rmtree(output_dir)
+                print(f"Cleaned up failed fine-tuning artifacts at: {output_dir}")
+            except Exception as cleanup_error:
+                print(f"Warning: Could not clean up output directory: {cleanup_error}")
+        raise
 
     finally:
         global_manager.settings_cache.clear()
@@ -485,6 +513,19 @@ async def start_finetuning_page(request: Request, background_task: BackgroundTas
             status_code=400,
             detail="A finetuning is already in progress. Please wait until it completes."
         )
+
+    # Validate available disk space (require at least 10GB free)
+    try:
+        stat = shutil.disk_usage(global_manager.model_path)
+        available_gb = stat.free / (1024 ** 3)  # Convert to GB
+        if available_gb < 10:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Insufficient disk space. Available: {available_gb:.2f} GB. Required: at least 10GB."
+            )
+    except HTTPException:
+        raise  # Re-raise so the disk-space error is not swallowed by the generic handler below
+    except Exception as e:
+        print(f"Warning: Could not check disk space: {e}")
+
     global_manager.finetuning_status["status"] = "initializing"
     global_manager.finetuning_status["message"] = "Starting finetuning process..."
     if global_manager.settings_builder.task == "text-generation":
@@ -508,7 +549,7 @@ async def start_finetuning_page(request: Request, background_task: BackgroundTas
     else:
         raise HTTPException(
             status_code=400,
-            detail="Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'."
+            detail=f"Invalid task. Must be one of {VALID_TASKS_STR}."
         )
 
     llm_tuner.set_settings(**global_manager.settings_builder.get_settings())
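
Note on the validator change: the batch-size/compute-specs rule moves from a field_validator to a model_validator(mode='after'). In Pydantic v2 a field validator only receives its own field's value, so the removed signature that also expected compute_specs never actually saw that field; an after-model validator runs on the fully constructed model and is the idiomatic place for cross-field rules. Below is a minimal, self-contained sketch of the pattern; the BatchSizeSettings class and its two-field subset are illustrative stand-ins, not the project's actual settings model, while the thresholds mirror the ones in this diff.

from pydantic import BaseModel, ValidationError, model_validator

class BatchSizeSettings(BaseModel):
    # Hypothetical stand-in for the router's finetuning settings model
    compute_specs: str
    per_device_train_batch_size: int

    @model_validator(mode="after")
    def validate_batch_size_with_compute_specs(self):
        # Cross-field rule mirrored from the diff: non-high_end devices cap at 3, high_end at 8
        if self.per_device_train_batch_size > 3 and self.compute_specs != "high_end":
            raise ValueError("Batch size must be 3 or less. Your device cannot support a higher batch size.")
        if self.per_device_train_batch_size > 8 and self.compute_specs == "high_end":
            raise ValueError("Batch size must be 8 or less. Higher batch sizes cause out of memory error.")
        return self

try:
    BatchSizeSettings(compute_specs="standard", per_device_train_batch_size=4)
except ValidationError as exc:
    print(exc)  # the cross-field check rejects a batch size of 4 on a non-high_end device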
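The docstring added to validate_custom_model leaves the architecture-task compatibility check as future work. One possible way to implement it (a sketch only, not part of this commit) is to compare the repository's pipeline_tag on the Hugging Face Hub against the selected task; the check_model_task_compatibility helper and the task-to-tag mapping below are assumptions, not existing project code.

from huggingface_hub import model_info
from huggingface_hub.utils import RepositoryNotFoundError

# Assumed mapping from this project's task names to acceptable Hub pipeline tags.
TASK_TO_PIPELINE_TAGS = {
    "text-generation": {"text-generation"},
    "summarization": {"summarization", "text2text-generation"},
    "extractive-question-answering": {"question-answering"},
}

def check_model_task_compatibility(repo_name: str, task: str) -> None:
    """Raise ValueError if the repo is missing or its pipeline tag conflicts with the chosen task."""
    try:
        info = model_info(repo_name)
    except RepositoryNotFoundError:
        raise ValueError(f"Repository '{repo_name}' was not found on the Hugging Face Hub.")
    allowed = TASK_TO_PIPELINE_TAGS.get(task, set())
    if info.pipeline_tag and info.pipeline_tag not in allowed:
        raise ValueError(
            f"Model '{repo_name}' is tagged '{info.pipeline_tag}', which does not match task '{task}'."
        )

Untagged repositories pass through unchanged here, since pipeline_tag can be None on the Hub; a stricter policy could reject them instead.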