Skip to content

Commit 3745233

Browse files
committed
Add session manager
1 parent cf569eb commit 3745233

File tree

3 files changed

+98
-21
lines changed

3 files changed

+98
-21
lines changed

src/rest/rest.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
UploadFileRequest,
1212
DeleteFilesRequest,
1313
ListDirectoryResponse,
14-
CreateNamespaceRequest,
15-
AudioNamespaceProgressInitial,
1614
)
1715
from src.service.audio import AudioService
1816
from src.service.file import FileService
@@ -64,18 +62,9 @@ async def list_namespaces(self):
6462
namespaces = self.namespace_service.get_namespaces()
6563
return {"namespaces": namespaces}
6664

67-
async def new_namespace(self, new_namespace_request: CreateNamespaceRequest):
65+
async def new_namespace(self):
6866
"""Create a new namespace."""
69-
args = new_namespace_request.args
70-
if new_namespace_request.service_name == "audio":
71-
namespace = self.namespace_service.create_namespace(new_namespace_request.service_name, args)
72-
namespace.progress = AudioNamespaceProgressInitial()
73-
self.namespace_service.submit_namespace(namespace)
74-
audio_service = AudioService(args["source_dir"], args["output_dir"], namespace)
75-
audio_service.audio_service()
76-
return namespace
77-
78-
namespace = self.namespace_service.create_namespace(new_namespace_request.service_name, new_namespace_request.args)
67+
namespace = self.namespace_service.create_namespace()
7968
return namespace
8069

8170
async def change_namespace(self, namespace_id: str, update_request: UpdateNamespaceRequest):

src/service/namespace.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def filter_namespaces(self, fn: Callable[[Namespace], bool]) -> List[Namespace]:
3838
"""
3939
return sorted(list(filter(fn, self._namespaces.values())), key=lambda t: t.createdAt)
4040

41-
def create_namespace(self, service_name: str, args: dict) -> Namespace:
41+
def create_namespace(self) -> Namespace:
4242
"""Create a new namespace."""
4343
namespace_id = str(uuid.uuid4())
4444
namespace_name = f"Namespace-{namespace_id[:8]}"
@@ -51,17 +51,10 @@ def create_namespace(self, service_name: str, args: dict) -> Namespace:
5151
name=namespace_name,
5252
createdAt=created_at,
5353
homePath=home_path,
54-
service_name=service_name,
55-
args=args,
56-
progress=Progress(),
5754
)
5855
self._namespaces[namespace_id] = namespace
5956
return namespace
6057

61-
def submit_namespace(self, namespace: Namespace):
62-
self._save_namespace_metadata(namespace)
63-
self._namespaces[namespace.namespaceID] = namespace
64-
6558
def get_namespaces(self) -> List[Namespace]:
6659
"""Get all namespaces."""
6760
namespaces = []

src/service/session.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import threading
2+
from functools import wraps
3+
from typing import Optional, Dict, Any
4+
from enum import Enum
5+
6+
class Status(Enum):
    """Lifecycle state of a tracked task session.

    The member values are the human-readable labels stored alongside the
    session; members are compared by identity/equality elsewhere in this
    module (e.g. ``status == Status.RUNNING``).
    """

    # The task has been started and has not finished yet.
    RUNNING = "Running"
    # The task returned normally.
    COMPLETED = "Completed"
    # The task raised an exception.
    FAILED = "Failed"
10+
11+
class SessionManager:
    """Manages the training session, ensuring a single task runs at a time.

    The class is a singleton: every ``SessionManager()`` call returns the same
    object. Its state lives in ``current_session``, which is either ``None``
    (no task has been started yet) or a dict holding ``task_name``, ``status``
    (a ``Status`` member) and ``error``, plus ``result`` once the task has
    completed and any extra keys merged in by ``update_session_info``.

    All state transitions run under ``_lock`` so that the "is a task already
    running?" check and the assignment of the new session are one atomic step.
    """

    _instance = None
    # Guards singleton creation and every read-modify-write of current_session.
    _lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created the
                # instance while this one was waiting.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Initialise before publishing so that no caller can ever
                    # observe an instance without ``current_session``.
                    instance.current_session = None
                    cls._instance = instance
        return cls._instance

    def start_session(self, task_name: str):
        """Start a new session for ``task_name``.

        Raises:
            RuntimeError: if the current session is still in RUNNING state.
        """
        with self._lock:
            session = self.current_session
            if session and session["status"] == Status.RUNNING:
                raise RuntimeError(f"A task '{session['task_name']}' is already running. Cannot submit another task!")

            # Replace (not update) the dict so nothing from a previous session
            # leaks into the new one.
            self.current_session = {
                "task_name": task_name,
                "status": Status.RUNNING,
                "error": None,  # Stores error details if task fails
            }

    def end_session(self, result: Any):
        """Mark the current session as completed and store ``result``.

        Does nothing when no session has ever been started.
        """
        with self._lock:
            if self.current_session:
                self.current_session["status"] = Status.COMPLETED
                self.current_session["result"] = result

    def fail_session(self, error: str):
        """Mark the current session as failed and store the error text.

        Does nothing when no session has ever been started.
        """
        with self._lock:
            if self.current_session:
                self.current_session["status"] = Status.FAILED
                self.current_session["error"] = error

    def update_session_info(self, info: Dict[str, Any]):
        """Merge ``info`` into the running session.

        Raises:
            RuntimeError: if there is no session or it is not RUNNING.
        """
        with self._lock:
            if not self.current_session or self.current_session["status"] != Status.RUNNING:
                raise RuntimeError("No active task to update session info!")

            self.current_session.update(info)

    def get_session_info(self) -> Optional[Dict[str, Any]]:
        """Return the current session dict, or ``None`` if none was started.

        The live dict is returned, not a copy.
        """
        return self.current_session
58+
59+
# Decorator to wrap task execution logic
60+
def session_guard(task_name: str):
    """Decorator factory that runs a task inside a SessionManager session.

    The wrapped function is executed between ``start_session(task_name)`` and
    either ``end_session(result)`` (normal return) or ``fail_session(str(e))``
    (the task raised).

    Args:
        task_name: Name recorded in the session for the wrapped task.

    Returns:
        A decorator. The decorated function returns the task's result, or
        ``None`` when the task raised: the exception is recorded in the
        session and deliberately not re-raised.

    Raises:
        RuntimeError: from the decorated call, when another task is already
            running and the new one is rejected.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            session_manager = SessionManager()

            # Claim the session outside the try block. If the claim is
            # rejected, the current session belongs to the task that is
            # already running; recording the rejection through fail_session
            # would flip that task to FAILED. Let the rejection propagate to
            # the submitter instead.
            session_manager.start_session(task_name)

            try:
                result = func(*args, **kwargs)  # Execute the training task
            except Exception as e:
                session_manager.fail_session(str(e))  # Record failure details
                # NOTICE: the exception is not re-raised here; the error is
                # captured in the session info and the caller receives None.
                return None
            else:
                session_manager.end_session(result)
                return result
        return wrapper
    return decorator
79+
80+
# Example of using SessionManager and the session_guard decorator.
81+
@session_guard("TrainingModel")
def train_model():
    """Example task: simulate a five-epoch training run that fails midway.

    Progress, loss and epoch number are pushed into the session after each
    epoch. A RuntimeError is raised when epoch 3 is reached, before that
    epoch reports anything, to demonstrate the failure path.
    """
    total_epochs = 5
    failing_epoch = 3  # Simulate task failure
    manager = SessionManager()

    epoch = 1
    while epoch <= total_epochs:
        if epoch == failing_epoch:
            raise RuntimeError("Error occurred at epoch 3!")

        report = {
            "progress": epoch / total_epochs,
            "loss": 0.05 * (total_epochs + 1 - epoch),
            "epoch": epoch,
        }
        manager.update_session_info(report)
        epoch += 1

    return "Training Completed"

0 commit comments

Comments
 (0)