Merged
7 changes: 4 additions & 3 deletions ModelForge/app.py
@@ -23,9 +23,10 @@
## Static files
frontend_dir = os.path.join(os.path.dirname(__file__), "./Frontend/build")
app_name = "ModelForge"
origins = [
"http://localhost:8000",
]

# CORS origins - configurable via environment variable
cors_origins_env = os.getenv("CORS_ORIGINS", "http://localhost:8000")
origins = [origin.strip() for origin in cors_origins_env.split(",")]

app.add_middleware(
CORSMiddleware,
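Review note: the hard-coded `origins` list gives way to a comma-separated `CORS_ORIGINS` environment variable. A minimal sketch of how the new parsing behaves (the second origin is a hypothetical example value, not part of this PR):

```python
import os

# Hypothetical deployment value; any comma-separated list works.
os.environ["CORS_ORIGINS"] = "http://localhost:8000, https://modelforge.example.com"

cors_origins_env = os.getenv("CORS_ORIGINS", "http://localhost:8000")
origins = [origin.strip() for origin in cors_origins_env.split(",")]
print(origins)  # ['http://localhost:8000', 'https://modelforge.example.com']
```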
29 changes: 19 additions & 10 deletions ModelForge/globals/globals.py
@@ -7,6 +7,7 @@

class GlobalSettings:
_instance = None
_initialized = False
file_manager = None
hardware_detector = None
settings_builder = None
@@ -23,19 +24,27 @@ def __new__(cls):
return cls._instance

def __init__(self):
self.file_manager = FileManager()
self.hardware_detector = HardwareDetector()
self.settings_builder = SettingsBuilder(None, None, None)
self.settings_cache = {}
self.finetuning_status = {"status": "idle", "progress": 0, "message": ""}
self.datasets_dir = self.file_manager.return_default_dirs()["datasets"]
self.model_path = self.file_manager.return_default_dirs()["models"]
self.db_manager = DatabaseManager(db_path=os.path.join(self.file_manager.return_default_dirs()["database"], "modelforge.sqlite"))
self.app_name = "ModelForge"
# Only initialize once to maintain singleton behavior
if not GlobalSettings._initialized:
self.file_manager = FileManager()
self.hardware_detector = HardwareDetector()
self.settings_builder = SettingsBuilder(None, None, None)
self.settings_cache = {}
# NOTE: finetuning_status is accessed from multiple places (callback, background task)
# without locking. The GIL keeps single dict operations atomic, but compound
# read-modify-write sequences are not. Consider threading.Lock if races appear.
self.finetuning_status = {"status": "idle", "progress": 0, "message": ""}
self.datasets_dir = self.file_manager.return_default_dirs()["datasets"]
self.model_path = self.file_manager.return_default_dirs()["models"]
self.db_manager = DatabaseManager(db_path=os.path.join(self.file_manager.return_default_dirs()["database"], "modelforge.sqlite"))
self.app_name = "ModelForge"
GlobalSettings._initialized = True

@classmethod
def get_instance(cls):
return cls.__new__(cls)
if cls._instance is None:
cls._instance = cls()
return cls._instance

def clear_settings_cache(self):
self.settings_cache.clear()
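Review note: the `_initialized` guard matters because Python runs `__init__` on every `GlobalSettings()` call, even when `__new__` hands back the existing instance; without it, each call would silently reset `settings_cache` and `finetuning_status`. A standalone sketch of the pattern:

```python
class Singleton:
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Without this guard, state would be wiped on every Singleton() call.
        if not Singleton._initialized:
            self.state = {}
            Singleton._initialized = True

a, b = Singleton(), Singleton()
assert a is b and a.state is b.state  # one shared instance, initialized once
```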
79 changes: 60 additions & 19 deletions ModelForge/routers/finetuning_router.py
@@ -1,13 +1,14 @@
from datetime import datetime
import json
import os
import shutil
import traceback
import uuid

from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi import Request
from starlette.responses import JSONResponse
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_validator

from ..utilities.finetuning.CausalLLMTuner import CausalLLMFinetuner
from ..utilities.finetuning.QuestionAnsweringTuner import QuestionAnsweringTuner
@@ -20,13 +21,17 @@
prefix="/finetune",
)

# Valid task types
VALID_TASKS = ["text-generation", "summarization", "extractive-question-answering"]
VALID_TASKS_STR = "'text-generation', 'summarization', or 'extractive-question-answering'"

## Pydantic Data Validator Classes
class TaskFormData(BaseModel):
task: str
@field_validator("task")
def validate_task(cls, task):
if task not in ["text-generation", "summarization", "extractive-question-answering"]:
raise ValueError("Invalid task. Must be one of 'text-generation', 'summarization', or 'extractive-question-answering'.")
if task not in VALID_TASKS:
raise ValueError(f"Invalid task. Must be one of {VALID_TASKS_STR}.")
return task

class SelectedModelFormData(BaseModel):
@@ -83,8 +88,8 @@ def validate_dataset_prescence(cls, dataset):
return dataset
@field_validator("task")
def validate_task(cls, task):
if task not in ["text-generation", "summarization", "question-answering"]:
raise ValueError("Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'.")
if task not in VALID_TASKS:
raise ValueError(f"Invalid task. Must be one of {VALID_TASKS_STR}.")
return task
@field_validator("model_name")
def validate_model_name(cls, model_name):
@@ -95,7 +100,7 @@ def validate_model_name(cls, model_name):
def validate_num_train_epochs(cls, num_train_epochs):
if num_train_epochs <= 0:
raise ValueError("Number of training epochs must be greater than 0.")
elif num_train_epochs > 30:
if num_train_epochs >= 50:
raise ValueError("Number of training epochs must be less than 50.")
return num_train_epochs
@field_validator("compute_specs")
@@ -154,13 +159,9 @@ def validate_bf16(cls, bf16):
raise ValueError("bf16 must be true or false.")
return bf16
@field_validator("per_device_train_batch_size")
def validate_per_device_train_batch_size(cls, per_device_train_batch_size, compute_specs):
def validate_per_device_train_batch_size(cls, per_device_train_batch_size):
if per_device_train_batch_size <= 0:
raise ValueError("Batch size must be greater than 0.")
elif per_device_train_batch_size > 3 and compute_specs != "high_end":
raise ValueError("Batch size must be less than 4. Your device cannot support a higher batch size.")
elif per_device_train_batch_size > 8 and compute_specs == "high_end":
raise ValueError("Batch size must be less than 9. Higher batch sizes cause out of memory error.")
return per_device_train_batch_size
@field_validator("per_device_eval_batch_size")
def validate_per_device_eval_batch_size(cls, per_device_eval_batch_size):
@@ -236,11 +237,20 @@ def validate_dataset(cls, dataset):
if not dataset:
raise ValueError("Dataset cannot be empty.")
return dataset

@model_validator(mode='after')
def validate_batch_size_with_compute_specs(self):
"""Validate batch size based on compute specs"""
if self.per_device_train_batch_size > 3 and self.compute_specs != "high_end":
raise ValueError("Batch size must be 3 or less. Your device cannot support a higher batch size.")
elif self.per_device_train_batch_size > 8 and self.compute_specs == "high_end":
raise ValueError("Batch size must be 8 or less. Higher batch sizes cause out of memory error.")
return self
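Review note: moving the cross-field rule into `@model_validator(mode="after")` is the right Pydantic v2 idiom: a `field_validator` only sees fields validated before it, so an after-model validator is the robust way to relate two fields. A self-contained sketch (`TrainingConfig` is a hypothetical stand-in for the form model):

```python
from pydantic import BaseModel, ValidationError, model_validator

class TrainingConfig(BaseModel):
    compute_specs: str
    per_device_train_batch_size: int

    @model_validator(mode="after")
    def check_batch_size(self):
        # Runs after all fields are parsed, so both values are available here.
        if self.per_device_train_batch_size > 3 and self.compute_specs != "high_end":
            raise ValueError("Batch size must be 3 or less on this hardware tier.")
        return self

TrainingConfig(compute_specs="mid_range", per_device_train_batch_size=2)  # passes
try:
    TrainingConfig(compute_specs="mid_range", per_device_train_batch_size=8)
except ValidationError as e:
    print(e)  # reports the cross-field batch-size violation
```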


@router.get("/detect")
async def detect_hardware_page(request: Request) -> JSONResponse:
global_manager.clear_global_manager.settings_cache() # Clear the cache to ensure fresh detection
global_manager.clear_settings_cache() # Clear the cache to ensure fresh detection
return JSONResponse({
"app_name": global_manager.app_name,
"message": "Ready to detect hardware"
@@ -290,7 +300,7 @@ async def detect_hardware(request: Request) -> JSONResponse:
print(e)
raise HTTPException(
status_code=400,
detail="Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'."
detail=f"Invalid task. Must be one of {VALID_TASKS_STR}."
)
except Exception as e:
print("General exception triggered")
@@ -317,6 +327,14 @@ async def set_model(request: Request) -> None:

@router.post("/validate_custom_model", response_class=JSONResponse)
async def validate_custom_model(request: Request) -> JSONResponse:
"""
Validate a custom model from HuggingFace Hub.

Note: Currently validates repository existence but not task compatibility.
Consider adding architecture-task compatibility checks (e.g., ensure
summarization models aren't used for text generation tasks) for better
user experience and to prevent fine-tuning failures.
"""
try:
form = await request.json()
validation_data = CustomModelValidationData(repo_name=form["repo_name"])
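One possible shape for the compatibility check the docstring suggests, assuming `huggingface_hub` is available; `check_task_compatibility` is a hypothetical helper, and `pipeline_tag` can be `None` for repos that don't declare one:

```python
from huggingface_hub import model_info

def check_task_compatibility(repo_name: str, task: str) -> bool:
    """Hypothetical helper: flag repos whose declared pipeline mismatches the task."""
    info = model_info(repo_name)
    # pipeline_tag is e.g. "text-generation" or "summarization" when the repo sets it
    return info.pipeline_tag is None or info.pipeline_tag == task
```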
@@ -424,15 +442,14 @@ async def load_settings(json_file: UploadFile = File(...), settings: str = Form(


def finetuning_task(llm_tuner) -> None:
output_dir = None
try:
llm_tuner.load_dataset(global_manager.settings_builder.dataset)
output_dir = llm_tuner.output_dir # Store for cleanup on failure
path = llm_tuner.finetune()

# Handle both absolute and relative paths
if os.path.isabs(path):
model_path = path
else:
model_path = os.path.join(os.path.dirname(__file__), path.replace("./", ""))
# Use the path returned from finetune (should be absolute)
model_path = os.path.abspath(path) if not os.path.isabs(path) else path

model_data = {
"model_name": global_manager.settings_builder.fine_tuned_name.split('/')[-1] if global_manager.settings_builder.fine_tuned_name else os.path.basename(model_path),
@@ -445,6 +462,17 @@ def finetuning_task(llm_tuner) -> None:
"is_custom_base_model": global_manager.settings_builder.is_custom_model
}
global_manager.db_manager.add_model(model_data)

except Exception as e:
print(f"Fine-tuning failed: {e}")
# Cleanup failed fine-tuning artifacts
if output_dir and os.path.exists(output_dir):
try:
shutil.rmtree(output_dir)
print(f"Cleaned up failed fine-tuning artifacts at: {output_dir}")
except Exception as cleanup_error:
print(f"Warning: Could not cleanup output directory: {cleanup_error}")
raise

finally:
global_manager.settings_cache.clear()
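The cleanup-on-failure flow added to `finetuning_task` boils down to a small pattern; a self-contained sketch, with a stand-in `train()` in place of `llm_tuner.finetune()`:

```python
import os
import shutil

def run_with_cleanup(output_dir: str) -> str:
    def train() -> str:  # stand-in for llm_tuner.finetune()
        raise RuntimeError("simulated fine-tuning failure")

    os.makedirs(output_dir, exist_ok=True)
    try:
        return train()
    except Exception:
        # Drop partial artifacts so a failed run leaves nothing behind.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise  # re-raise so callers and status tracking still see the failure
```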
@@ -485,6 +513,19 @@ async def start_finetuning_page(request: Request, background_task: BackgroundTas
status_code=400,
detail="A finetuning is already in progress. Please wait until it completes."
)

# Validate available disk space (require at least 10GB free)
try:
stat = shutil.disk_usage(global_manager.model_path)
available_gb = stat.free / (1024 ** 3) # Convert to GB
if available_gb < 10:
raise HTTPException(
status_code=400,
detail=f"Insufficient disk space. Available: {available_gb:.2f}GB. Required: at least 10GB."
)
except HTTPException:
raise  # re-raise so the insufficient-space error above isn't swallowed
except Exception as e:
print(f"Warning: Could not check disk space: {e}")
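For reference, `shutil.disk_usage` returns a named tuple of byte counts, so the check above is just a unit conversion:

```python
import shutil

total, used, free = shutil.disk_usage(".")  # all three values are in bytes
print(f"free: {free / (1024 ** 3):.2f} GiB")
```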

global_manager.finetuning_status["status"] = "initializing"
global_manager.finetuning_status["message"] = "Starting finetuning process..."
if global_manager.settings_builder.task == "text-generation":
@@ -508,7 +549,7 @@ else:
else:
raise HTTPException(
status_code=400,
detail="Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'."
detail=f"Invalid task. Must be one of {VALID_TASKS_STR}."
)

llm_tuner.set_settings(**global_manager.settings_builder.get_settings())
6 changes: 3 additions & 3 deletions ModelForge/routers/hub_management_router.py
@@ -72,8 +72,8 @@ async def push_model_to_hub(request: Request) -> JSONResponse:
f"Please ensure your huggingface token grants you write access. "
f"If you are pushing to an organization, contact the administrator for write access."}, status_code=403)
except HfHubHTTPError as e:
return JSONResponse({f"error": "Failed to push model to HuggingFace Hub. "
f"Please check your network connection and authentication token."
return JSONResponse({"error": f"Failed to push model to HuggingFace Hub. "
f"Please check your network connection and authentication token. "
f"Error received is: {e}"}, status_code=500)
except Exception as e:
return JSONResponse({"Unknown error": str(e)}, status_code=500)
return JSONResponse({"error": str(e)}, status_code=500)
21 changes: 18 additions & 3 deletions ModelForge/routers/playground_router.py
@@ -7,10 +7,21 @@

from ..globals.globals_instance import global_manager

from pydantic import BaseModel, field_validator

router = APIRouter(
prefix="/playground",
)

class PlaygroundRequest(BaseModel):
model_path: str

@field_validator("model_path")
def validate_model_path(cls, model_path):
if not model_path or not model_path.strip():
raise ValueError("Model path cannot be empty.")
return model_path.strip()

@router.get("/model_path")
async def get_model_path(request: Request) -> JSONResponse:
return JSONResponse({
@@ -23,20 +34,24 @@ async def get_model_path(request: Request) -> JSONResponse:
async def new_playground(request: Request) -> None:
form = await request.json()
print(form)
model_path = form["model_path"]
playground_request = PlaygroundRequest(model_path=form["model_path"])
model_path = playground_request.model_path

base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "utilities"))
chat_script = os.path.join(base_path, "chat_playground.py")
if os.name == "nt": # Windows
# Note: shell=True is required for 'start' command (cmd.exe built-in)
# Input is validated via PlaygroundRequest Pydantic model
command = ["cmd.exe", "/c", "start", "python", chat_script, "--model_path", model_path]
subprocess.Popen(command, shell=True)
else: # Unix/Linux/Mac
command = ["x-terminal-emulator", "-e", f"python {chat_script} --model_path {model_path}"]
# Use list format without string interpolation for security
command = ["x-terminal-emulator", "-e", "python", chat_script, "--model_path", model_path]
try:
subprocess.Popen(command)
except FileNotFoundError:
# Fallback to gnome-terminal or xterm if x-terminal-emulator is not available
try:
subprocess.Popen(["gnome-terminal", "--", "python3", chat_script, "--model_path", model_path])
except FileNotFoundError:
subprocess.Popen(["xterm", "-e", f"python3 {chat_script} --model_path {model_path}"])
subprocess.Popen(["xterm", "-e", "python3", chat_script, "--model_path", model_path])
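Review note on the list-argv change: with list form, each element reaches the child process verbatim, so shell metacharacters in `model_path` are never interpreted. A quick demonstration (mirrors the Unix branch; assumes `python3` is on PATH):

```python
import subprocess

model_path = "models/demo; echo injected"  # hostile-looking input
# List argv: the argument arrives as-is; no shell ever parses the semicolon.
out = subprocess.run(
    ["python3", "-c", "import sys; print(sys.argv[1])", model_path],
    capture_output=True, text=True,
).stdout
print(out)  # models/demo; echo injected
```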
32 changes: 8 additions & 24 deletions ModelForge/utilities/finetuning/Seq2SeqLMTuner.py
@@ -20,30 +20,14 @@ def format_example(example: dict, specs: str, keys=None) -> Dict | None:
if keys is None:
keys = ["article", "summary"]

if specs == "low_end":
return {
"text": f'''
["role": "system", "content": "You are a text summarization assistant."],
[role": "user", "content": {example[keys[0]]}],
["role": "assistant", "content": {example[keys[1]]}]
'''
}
elif specs == "mid_range":
return {
"text": f'''
["role": "system", "content": "You are a text summarization assistant."],
[role": "user", "content": {example[keys[0]]}],
["role": "assistant", "content": {example[keys[1]]}]
'''
}
elif specs == "high_end":
return {
"text": f'''
["role": "system", "content": "You are a text summarization assistant."],
[role": "user", "content": {example[keys[0]]}],
["role": "assistant", "content": {example[keys[1]]}]
'''
}
# Format is the same regardless of specs, so we can simplify
return {
"text": f'''
["role": "system", "content": "You are a text summarization assistant."],
["role": "user", "content": {example[keys[0]]}],
["role": "assistant", "content": {example[keys[1]]}]
'''
}

def load_dataset(self, dataset_path: str) -> None:
dataset = load_dataset("json", data_files=dataset_path, split="train")
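For a quick sanity check, here is a stand-alone replica of the simplified helper (same body as above, lifted out of the class) and the text it renders for a toy example:

```python
def format_example(example: dict, specs: str, keys=None) -> dict:
    if keys is None:
        keys = ["article", "summary"]
    # specs no longer affects the output; kept only for signature compatibility
    return {
        "text": f'''
            ["role": "system", "content": "You are a text summarization assistant."],
            ["role": "user", "content": {example[keys[0]]}],
            ["role": "assistant", "content": {example[keys[1]]}]
        '''
    }

sample = {"article": "Long article text...", "summary": "Short summary."}
print(format_example(sample, specs="low_end")["text"])
```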
8 changes: 8 additions & 0 deletions ModelForge/utilities/settings_managers/DBManager.py
@@ -6,6 +6,13 @@
from typing import Any

class DatabaseManager:
"""
Manages SQLite database operations for ModelForge.

Note: Currently opens/closes connections for each operation. For better performance
in high-traffic scenarios, consider implementing connection pooling using libraries
like SQLAlchemy or maintaining a connection pool manually.
"""
_instance = None

def __new__(cls, *args, **kwargs):
@@ -153,4 +160,5 @@ def delete_model(self, model_id) -> bool:
def kill_connection(self) -> None:
if self.conn:
self.conn.close()
self.conn = None
del self.cursor
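A lightweight step toward what the docstring suggests is a per-operation connection context manager; a sketch, not the module's current API (SQLAlchemy's built-in pooling is the heavier alternative):

```python
import sqlite3
from contextlib import contextmanager

@contextmanager
def get_conn(db_path: str):
    """Open, commit-or-rollback, and close around a single operation."""
    conn = sqlite3.connect(db_path)
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

# usage sketch:
# with get_conn("modelforge.sqlite") as conn:
#     conn.execute("SELECT 1")
```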
2 changes: 1 addition & 1 deletion ModelForge/utilities/settings_managers/FileManager.py
@@ -50,7 +50,7 @@ def validate_or_create_file(cls, check_path: str) -> str:
return check_path

@classmethod
def save_file(cls, content:bytes, file_path: str) -> str | None:
def save_file(cls, file_path: str, content: bytes) -> str | None:
try:
file_dir = os.path.dirname(file_path)
cls.validate_or_create_dirs(os.path.abspath(file_dir))
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -18,9 +18,9 @@ dependencies = [
"setuptools==78.1.0",
"tensorboard==2.19.0",
"tensorboard-data-server==0.7.2",
"tokenizers==0.21.1",
"tokenizers==0.21.0",
"tqdm==4.67.1",
"transformers>=4.45.1",
"transformers==4.48.3",
"trl==0.16.0",
"uvicorn",
"platformdirs",