Skip to content

Commit 369f077

Browse files
Copilot, RETR0-OS, Copilot
authored
Fix 20 critical bugs: broken validators, resource leaks, and security issues (#42)
* Initial plan
* Phase 1: Fix critical bugs (Issues #1, #4, #5, #6)
  Co-authored-by: RETR0-OS <[email protected]>
* Phase 2: Data validation and consistency fixes (Issues #2, #3, #13)
  Co-authored-by: RETR0-OS <[email protected]>
* Phase 3: Security and resource management fixes (Issues #8, #9, #17, #18)
  Co-authored-by: RETR0-OS <[email protected]>
* Phase 4: Architectural improvements (Issues #10-#20)
  Co-authored-by: RETR0-OS <[email protected]>
* Address code review feedback
  Co-authored-by: RETR0-OS <[email protected]>
* Fix batch size error messages to be clearer
  Co-authored-by: RETR0-OS <[email protected]>
* resolve comments
* Update ModelForge/routers/finetuning_router.py
  Co-authored-by: Copilot <[email protected]>
* Extract task validation to constants for maintainability
  Co-authored-by: RETR0-OS <[email protected]>
---------
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: RETR0-OS <[email protected]>
Co-authored-by: RETR0-OS <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent c24abf0 commit 369f077

File tree

10 files changed

+144
-86
lines changed

10 files changed

+144
-86
lines changed

ModelForge/app.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,10 @@
2323
## Static files
2424
frontend_dir = os.path.join(os.path.dirname(__file__), "./Frontend/build")
2525
app_name = "ModelForge"
26-
origins = [
27-
"http://localhost:8000",
28-
]
26+
27+
# CORS origins - configurable via environment variable
28+
cors_origins_env = os.getenv("CORS_ORIGINS", "http://localhost:8000")
29+
origins = [origin.strip() for origin in cors_origins_env.split(",")]
2930

3031
app.add_middleware(
3132
CORSMiddleware,

ModelForge/globals/globals.py

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
class GlobalSettings:
99
_instance = None
10+
_initialized = False
1011
file_manager = None
1112
hardware_detector = None
1213
settings_builder = None
@@ -23,19 +24,27 @@ def __new__(cls):
2324
return cls._instance
2425

2526
def __init__(self):
26-
self.file_manager = FileManager()
27-
self.hardware_detector = HardwareDetector()
28-
self.settings_builder = SettingsBuilder(None, None, None)
29-
self.settings_cache = {}
30-
self.finetuning_status = {"status": "idle", "progress": 0, "message": ""}
31-
self.datasets_dir = self.file_manager.return_default_dirs()["datasets"]
32-
self.model_path = self.file_manager.return_default_dirs()["models"]
33-
self.db_manager = DatabaseManager(db_path=os.path.join(self.file_manager.return_default_dirs()["database"], "modelforge.sqlite"))
34-
self.app_name = "ModelForge"
27+
# Only initialize once to maintain singleton behavior
28+
if not GlobalSettings._initialized:
29+
self.file_manager = FileManager()
30+
self.hardware_detector = HardwareDetector()
31+
self.settings_builder = SettingsBuilder(None, None, None)
32+
self.settings_cache = {}
33+
# NOTE: finetuning_status is accessed from multiple places (callback, background task)
34+
# without locking. Python's GIL provides basic protection, but be cautious with
35+
# complex operations. Consider using threading.Lock if race conditions occur.
36+
self.finetuning_status = {"status": "idle", "progress": 0, "message": ""}
37+
self.datasets_dir = self.file_manager.return_default_dirs()["datasets"]
38+
self.model_path = self.file_manager.return_default_dirs()["models"]
39+
self.db_manager = DatabaseManager(db_path=os.path.join(self.file_manager.return_default_dirs()["database"], "modelforge.sqlite"))
40+
self.app_name = "ModelForge"
41+
GlobalSettings._initialized = True
3542

3643
@classmethod
3744
def get_instance(cls):
38-
return cls.__new__(cls)
45+
if cls._instance is None:
46+
cls._instance = cls()
47+
return cls._instance
3948

4049
def clear_settings_cache(self):
4150
self.settings_cache.clear()

ModelForge/routers/finetuning_router.py

Lines changed: 60 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,14 @@
11
from datetime import datetime
22
import json
33
import os
4+
import shutil
45
import traceback
56
import uuid
67

78
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
89
from fastapi import Request
910
from starlette.responses import JSONResponse
10-
from pydantic import BaseModel, field_validator
11+
from pydantic import BaseModel, field_validator, model_validator
1112

1213
from ..utilities.finetuning.CausalLLMTuner import CausalLLMFinetuner
1314
from ..utilities.finetuning.QuestionAnsweringTuner import QuestionAnsweringTuner
@@ -20,13 +21,17 @@
2021
prefix="/finetune",
2122
)
2223

24+
# Valid task types
25+
VALID_TASKS = ["text-generation", "summarization", "extractive-question-answering"]
26+
VALID_TASKS_STR = "'text-generation', 'summarization', or 'extractive-question-answering'"
27+
2328
## Pydantic Data Validator Classes
2429
class TaskFormData(BaseModel):
2530
task: str
2631
@field_validator("task")
2732
def validate_task(cls, task):
28-
if task not in ["text-generation", "summarization", "extractive-question-answering"]:
29-
raise ValueError("Invalid task. Must be one of 'text-generation', 'summarization', or 'extractive-question-answering'.")
33+
if task not in VALID_TASKS:
34+
raise ValueError(f"Invalid task. Must be one of {VALID_TASKS_STR}.")
3035
return task
3136

3237
class SelectedModelFormData(BaseModel):
@@ -83,8 +88,8 @@ def validate_dataset_prescence(cls, dataset):
8388
return dataset
8489
@field_validator("task")
8590
def validate_task(cls, task):
86-
if task not in ["text-generation", "summarization", "question-answering"]:
87-
raise ValueError("Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'.")
91+
if task not in VALID_TASKS:
92+
raise ValueError(f"Invalid task. Must be one of {VALID_TASKS_STR}.")
8893
return task
8994
@field_validator("model_name")
9095
def validate_model_name(cls, model_name):
@@ -95,7 +100,7 @@ def validate_model_name(cls, model_name):
95100
def validate_num_train_epochs(cls, num_train_epochs):
96101
if num_train_epochs <= 0:
97102
raise ValueError("Number of training epochs must be greater than 0.")
98-
elif num_train_epochs > 30:
103+
if num_train_epochs >= 50:
99104
raise ValueError("Number of training epochs must be less than 50.")
100105
return num_train_epochs
101106
@field_validator("compute_specs")
@@ -154,13 +159,9 @@ def validate_bf16(cls, bf16):
154159
raise ValueError("bf16 must be true or false.")
155160
return bf16
156161
@field_validator("per_device_train_batch_size")
157-
def validate_per_device_train_batch_size(cls, per_device_train_batch_size, compute_specs):
162+
def validate_per_device_train_batch_size(cls, per_device_train_batch_size):
158163
if per_device_train_batch_size <= 0:
159164
raise ValueError("Batch size must be greater than 0.")
160-
elif per_device_train_batch_size > 3 and compute_specs != "high_end":
161-
raise ValueError("Batch size must be less than 4. Your device cannot support a higher batch size.")
162-
elif per_device_train_batch_size > 8 and compute_specs == "high_end":
163-
raise ValueError("Batch size must be less than 9. Higher batch sizes cause out of memory error.")
164165
return per_device_train_batch_size
165166
@field_validator("per_device_eval_batch_size")
166167
def validate_per_device_eval_batch_size(cls, per_device_eval_batch_size):
@@ -236,11 +237,20 @@ def validate_dataset(cls, dataset):
236237
if not dataset:
237238
raise ValueError("Dataset cannot be empty.")
238239
return dataset
240+
241+
@model_validator(mode='after')
242+
def validate_batch_size_with_compute_specs(self):
243+
"""Validate batch size based on compute specs"""
244+
if self.per_device_train_batch_size > 3 and self.compute_specs != "high_end":
245+
raise ValueError("Batch size must be 3 or less. Your device cannot support a higher batch size.")
246+
elif self.per_device_train_batch_size > 8 and self.compute_specs == "high_end":
247+
raise ValueError("Batch size must be 8 or less. Higher batch sizes cause out of memory error.")
248+
return self
239249

240250

241251
@router.get("/detect")
242252
async def detect_hardware_page(request: Request) -> JSONResponse:
243-
global_manager.clear_global_manager.settings_cache() # Clear the cache to ensure fresh detection
253+
global_manager.clear_settings_cache() # Clear the cache to ensure fresh detection
244254
return JSONResponse({
245255
"app_name": global_manager.app_name,
246256
"message": "Ready to detect hardware"
@@ -290,7 +300,7 @@ async def detect_hardware(request: Request) -> JSONResponse:
290300
print(e)
291301
raise HTTPException(
292302
status_code=400,
293-
detail="Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'."
303+
detail=f"Invalid task. Must be one of {VALID_TASKS_STR}."
294304
)
295305
except Exception as e:
296306
print("General exception triggered")
@@ -317,6 +327,14 @@ async def set_model(request: Request) -> None:
317327

318328
@router.post("/validate_custom_model", response_class=JSONResponse)
319329
async def validate_custom_model(request: Request) -> JSONResponse:
330+
"""
331+
Validate a custom model from HuggingFace Hub.
332+
333+
Note: Currently validates repository existence but not task compatibility.
334+
Consider adding architecture-task compatibility checks (e.g., ensure
335+
summarization models aren't used for text generation tasks) for better
336+
user experience and to prevent fine-tuning failures.
337+
"""
320338
try:
321339
form = await request.json()
322340
validation_data = CustomModelValidationData(repo_name=form["repo_name"])
@@ -424,15 +442,14 @@ async def load_settings(json_file: UploadFile = File(...), settings: str = Form(
424442

425443

426444
def finetuning_task(llm_tuner) -> None:
445+
output_dir = None
427446
try:
428447
llm_tuner.load_dataset(global_manager.settings_builder.dataset)
448+
output_dir = llm_tuner.output_dir # Store for cleanup on failure
429449
path = llm_tuner.finetune()
430450

431-
# Handle both absolute and relative paths
432-
if os.path.isabs(path):
433-
model_path = path
434-
else:
435-
model_path = os.path.join(os.path.dirname(__file__), path.replace("./", ""))
451+
# Use the path returned from finetune (should be absolute)
452+
model_path = os.path.abspath(path) if not os.path.isabs(path) else path
436453

437454
model_data = {
438455
"model_name": global_manager.settings_builder.fine_tuned_name.split('/')[-1] if global_manager.settings_builder.fine_tuned_name else os.path.basename(model_path),
@@ -445,6 +462,17 @@ def finetuning_task(llm_tuner) -> None:
445462
"is_custom_base_model": global_manager.settings_builder.is_custom_model
446463
}
447464
global_manager.db_manager.add_model(model_data)
465+
466+
except Exception as e:
467+
print(f"Fine-tuning failed: {e}")
468+
# Cleanup failed fine-tuning artifacts
469+
if output_dir and os.path.exists(output_dir):
470+
try:
471+
shutil.rmtree(output_dir)
472+
print(f"Cleaned up failed fine-tuning artifacts at: {output_dir}")
473+
except Exception as cleanup_error:
474+
print(f"Warning: Could not cleanup output directory: {cleanup_error}")
475+
raise
448476

449477
finally:
450478
global_manager.settings_cache.clear()
@@ -485,6 +513,19 @@ async def start_finetuning_page(request: Request, background_task: BackgroundTas
485513
status_code=400,
486514
detail="A finetuning is already in progress. Please wait until it completes."
487515
)
516+
517+
# Validate available disk space (require at least 10GB free)
518+
try:
519+
stat = shutil.disk_usage(global_manager.model_path)
520+
available_gb = stat.free / (1024 ** 3) # Convert to GB
521+
if available_gb < 10:
522+
raise HTTPException(
523+
status_code=400,
524+
detail=f"Insufficient disk space. Available: {available_gb:.2f}GB. Required: at least 10GB."
525+
)
526+
except Exception as e:
527+
print(f"Warning: Could not check disk space: {e}")
528+
488529
global_manager.finetuning_status["status"] = "initializing"
489530
global_manager.finetuning_status["message"] = "Starting finetuning process..."
490531
if global_manager.settings_builder.task == "text-generation":
@@ -508,7 +549,7 @@ async def start_finetuning_page(request: Request, background_task: BackgroundTas
508549
else:
509550
raise HTTPException(
510551
status_code=400,
511-
detail="Invalid task. Must be one of 'text-generation', 'summarization', or 'question-answering'."
552+
detail=f"Invalid task. Must be one of {VALID_TASKS_STR}."
512553
)
513554

514555
llm_tuner.set_settings(**global_manager.settings_builder.get_settings())

ModelForge/routers/hub_management_router.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,8 @@ async def push_model_to_hub(request: Request) -> JSONResponse:
7272
f"Please ensure your huggingface token grants you write access. "
7373
f"If you are pushing to an organization, contact the administrator for write access."}, status_code=403)
7474
except HfHubHTTPError as e:
75-
return JSONResponse({f"error": "Failed to push model to HuggingFace Hub. "
76-
f"Please check your network connection and authentication token."
75+
return JSONResponse({"error": f"Failed to push model to HuggingFace Hub. "
76+
f"Please check your network connection and authentication token. "
7777
f"Error received is: {e}"}, status_code=500)
7878
except Exception as e:
79-
return JSONResponse({"Unknown error": str(e)}, status_code=500)
79+
return JSONResponse({"error": str(e)}, status_code=500)

ModelForge/routers/playground_router.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,21 @@
77

88
from ..globals.globals_instance import global_manager
99

10+
from pydantic import BaseModel, field_validator
11+
1012
router = APIRouter(
1113
prefix="/playground",
1214
)
1315

16+
class PlaygroundRequest(BaseModel):
17+
model_path: str
18+
19+
@field_validator("model_path")
20+
def validate_model_path(cls, model_path):
21+
if not model_path or not model_path.strip():
22+
raise ValueError("Model path cannot be empty.")
23+
return model_path.strip()
24+
1425
@router.get("/model_path")
1526
async def get_model_path(request: Request) -> JSONResponse:
1627
return JSONResponse({
@@ -23,20 +34,24 @@ async def get_model_path(request: Request) -> JSONResponse:
2334
async def new_playground(request: Request) -> None:
2435
form = await request.json()
2536
print(form)
26-
model_path = form["model_path"]
37+
playground_request = PlaygroundRequest(model_path=form["model_path"])
38+
model_path = playground_request.model_path
2739

2840
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "utilities"))
2941
chat_script = os.path.join(base_path, "chat_playground.py")
3042
if os.name == "nt": # Windows
43+
# Note: shell=True is required for 'start' command (cmd.exe built-in)
44+
# Input is validated via PlaygroundRequest Pydantic model
3145
command = ["cmd.exe", "/c", "start", "python", chat_script, "--model_path", model_path]
3246
subprocess.Popen(command, shell=True)
3347
else: # Unix/Linux/Mac
34-
command = ["x-terminal-emulator", "-e", f"python {chat_script} --model_path {model_path}"]
48+
# Use list format without string interpolation for security
49+
command = ["x-terminal-emulator", "-e", "python", chat_script, "--model_path", model_path]
3550
try:
3651
subprocess.Popen(command)
3752
except FileNotFoundError:
3853
# Fallback to gnome-terminal or xterm if x-terminal-emulator is not available
3954
try:
4055
subprocess.Popen(["gnome-terminal", "--", "python3", chat_script, "--model_path", model_path])
4156
except FileNotFoundError:
42-
subprocess.Popen(["xterm", "-e", f"python3 {chat_script} --model_path {model_path}"])
57+
subprocess.Popen(["xterm", "-e", "python3", chat_script, "--model_path", model_path])

ModelForge/utilities/finetuning/Seq2SeqLMTuner.py

Lines changed: 8 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -20,30 +20,14 @@ def format_example(example: dict, specs: str, keys=None) -> Dict | None:
2020
if keys is None:
2121
keys = ["article", "summary"]
2222

23-
if specs == "low_end":
24-
return {
25-
"text": f'''
26-
["role": "system", "content": "You are a text summarization assistant."],
27-
[role": "user", "content": {example[keys[0]]}],
28-
["role": "assistant", "content": {example[keys[1]]}]
29-
'''
30-
}
31-
elif specs == "mid_range":
32-
return {
33-
"text": f'''
34-
["role": "system", "content": "You are a text summarization assistant."],
35-
[role": "user", "content": {example[keys[0]]}],
36-
["role": "assistant", "content": {example[keys[1]]}]
37-
'''
38-
}
39-
elif specs == "high_end":
40-
return {
41-
"text": f'''
42-
["role": "system", "content": "You are a text summarization assistant."],
43-
[role": "user", "content": {example[keys[0]]}],
44-
["role": "assistant", "content": {example[keys[1]]}]
45-
'''
46-
}
23+
# Format is the same regardless of specs, so we can simplify
24+
return {
25+
"text": f'''
26+
["role": "system", "content": "You are a text summarization assistant."],
27+
["role": "user", "content": {example[keys[0]]}],
28+
["role": "assistant", "content": {example[keys[1]]}]
29+
'''
30+
}
4731

4832
def load_dataset(self, dataset_path: str) -> None:
4933
dataset = load_dataset("json", data_files=dataset_path, split="train")

ModelForge/utilities/settings_managers/DBManager.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,13 @@
66
from typing import Any
77

88
class DatabaseManager:
9+
"""
10+
Manages SQLite database operations for ModelForge.
11+
12+
Note: Currently opens/closes connections for each operation. For better performance
13+
in high-traffic scenarios, consider implementing connection pooling using libraries
14+
like SQLAlchemy or maintaining a connection pool manually.
15+
"""
916
_instance = None
1017

1118
def __new__(cls, *args, **kwargs):
@@ -153,4 +160,5 @@ def delete_model(self, model_id) -> bool:
153160
def kill_connection(self) -> None:
154161
if self.conn:
155162
self.conn.close()
163+
self.conn = None
156164
del self.cursor

ModelForge/utilities/settings_managers/FileManager.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def validate_or_create_file(cls, check_path: str) -> str:
5050
return check_path
5151

5252
@classmethod
53-
def save_file(cls, content:bytes, file_path: str) -> str | None:
53+
def save_file(cls, file_path: str, content: bytes) -> str | None:
5454
try:
5555
file_dir = os.path.dirname(file_path)
5656
cls.validate_or_create_dirs(os.path.abspath(file_dir))

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,9 @@ dependencies = [
1818
"setuptools==78.1.0",
1919
"tensorboard==2.19.0",
2020
"tensorboard-data-server==0.7.2",
21-
"tokenizers==0.21.1",
21+
"tokenizers==0.21.0",
2222
"tqdm==4.67.1",
23-
"transformers>=4.45.1",
23+
"transformers==4.48.3",
2424
"trl==0.16.0",
2525
"uvicorn",
2626
"platformdirs",

0 commit comments

Comments
 (0)