Skip to content

Commit 559154a

Browse files
committed
fix: implement generation progress syncing and MPS memory tracking
1 parent 385d280 commit 559154a

File tree

4 files changed

+79
-42
lines changed

4 files changed

+79
-42
lines changed

backend/app_handler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ def __init__(
162162
device=config.device,
163163
)
164164

165-
self.generation = GenerationHandler(state=self.state, lock=self._lock)
165+
from state.job_queue import JobQueue
166+
self.job_queue = JobQueue(persistence_path=config.settings_file.parent / "job_queue.json")
167+
168+
self.generation = GenerationHandler(state=self.state, lock=self._lock, job_queue=self.job_queue)
166169

167170
self.video_generation = VideoGenerationHandler(
168171
state=self.state,
@@ -243,9 +246,6 @@ def __init__(
243246

244247
self.downloads.cleanup_downloading_dir()
245248

246-
from state.job_queue import JobQueue
247-
self.job_queue = JobQueue(persistence_path=config.settings_file.parent / "job_queue.json")
248-
249249
# Wire up the QueueWorker with concrete executors so submitted jobs
250250
# are dispatched to the appropriate generation handler.
251251
from handlers.job_executors import ApiJobExecutor, GpuJobExecutor

backend/handlers/generation_handler.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6-
from typing import Literal
6+
from typing import TYPE_CHECKING, Literal
77

88
from api_types import CancelResponse, GenerationProgressResponse
99
from handlers.base import StateHandlerBase, with_state_lock
@@ -17,11 +17,23 @@
1717
GpuSlot,
1818
)
1919

20+
if TYPE_CHECKING:
21+
from threading import RLock
22+
from state.app_state_types import AppState
23+
from state.job_queue import JobQueue
24+
2025
logger = logging.getLogger(__name__)
2126
GenerationSlot = Literal["gpu", "api"]
2227

23-
2428
class GenerationHandler(StateHandlerBase):
29+
def __init__(self, state: AppState, lock: RLock, job_queue: JobQueue | None = None) -> None:
    """Initialize the generation handler.

    Args:
        state: Shared application state object.
        lock: Re-entrant lock guarding mutations of ``state``.
        job_queue: Optional persistent job queue; when provided, progress
            updates are mirrored into it for the job currently in flight.
    """
    super().__init__(state, lock)
    # Persistent queue to sync progress into; may be None when queue syncing is disabled.
    self._job_queue = job_queue
    # ID of the queue job currently driving generation, set by the queue worker.
    self._current_job_id: str | None = None
33+
34+
def set_current_job_id(self, job_id: str | None) -> None:
    """Record which queue job is currently in flight.

    A falsy value (``None`` or an empty string) clears the context, so
    progress syncing is skipped when no real job is running.
    """
    self._current_job_id = job_id if job_id else None
36+
2537
@with_state_lock
2638
def start_generation(self, generation_id: str) -> None:
2739
if self.is_generation_running():
@@ -91,6 +103,14 @@ def update_progress(
91103
current_step: int | None = None,
92104
total_steps: int | None = None,
93105
) -> None:
106+
# Sync to the persistent JobQueue if we have a job context
107+
if self._job_queue and self._current_job_id:
108+
self._job_queue.update_job(
109+
self._current_job_id,
110+
progress=progress,
111+
phase=phase,
112+
)
113+
94114
match self._running_slot():
95115
case "gpu":
96116
match self.state.gpu_slot:

backend/handlers/job_executors.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,15 @@ def __init__(
5454

5555
def execute(self, job: QueueJob) -> list[str]:
    """Run a queued GPU job and return the generated output paths.

    Hands the queue job ID to the generation handler for the duration of
    the run so progress updates can be synced to the persistent queue,
    then dispatches to the image or video pipeline based on ``job.type``.
    The job context is always cleared afterwards, even on failure.
    """
    logger.info("[QueueWorker] Executing GPU job %s (type=%s model=%s)", job.id, job.type, job.model)
    # Pass the queue job ID to the generation handler so it can sync progress
    gen_handler = self._video._generation
    gen_handler.set_current_job_id(job.id)
    try:
        if job.type == "image":
            return self._execute_image(job)
        return self._execute_video(job)
    finally:
        # Clear the job context regardless of success or failure.
        gen_handler.set_current_job_id(None)
6066

6167
def _execute_video(self, job: QueueJob) -> list[str]:
6268
p = job.params
@@ -104,35 +110,40 @@ def __init__(
104110
self._image = image_generation
105111

106112
def execute(self, job: QueueJob) -> list[str]:
    """Run a queued API job and return the generated output paths.

    Builds an image or video request from ``job.params`` (with defaults
    for missing keys) and forwards it to the matching API handler. The
    queue job ID is registered with the generation handler for the
    duration of the run so progress updates can be synced to the
    persistent queue, and is always cleared afterwards.
    """
    # FIX: the format string has three %s placeholders but job.id was passed
    # twice (four arguments), which made the logging module print a
    # "--- Logging error ---" traceback on every API job.
    logger.info("[QueueWorker] Executing API job %s (type=%s model=%s)", job.id, job.type, job.model)
    # Pass the queue job ID to the generation handler so it can sync progress
    gen_handler = self._video._generation
    gen_handler.set_current_job_id(job.id)
    try:
        p = job.params
        if job.type == "image":
            req = GenerateImageRequest(
                prompt=_str(p, "prompt"),
                width=_int(p, "width", 1920),
                height=_int(p, "height", 1080),
                numImages=_int(p, "numImages", 1),
                numSteps=_int(p, "numSteps", 4),
            )
            result = self._image.generate(req)
            return list(result.image_paths or [])
        # Video job: optional path-like params fall back to None when empty.
        req = GenerateVideoRequest(
            prompt=_str(p, "prompt"),
            imagePath=_str(p, "imagePath") or None,
            lastFramePath=_str(p, "lastFramePath") or None,
            audioPath=_str(p, "audioPath") or None,
            resolution=_str(p, "resolution", "540p"),
            duration=_str(p, "duration", "5"),
            fps=_str(p, "fps", "24"),
            audio=_str(p, "audio", "false"),
            cameraMotion=_camera_motion(p),
            aspectRatio=_aspect_ratio(p),
            model=job.model,
            negativePrompt=_str(p, "negativePrompt"),
        )
        result = self._video.generate(req)
        if result.video_path:
            return [result.video_path]
        return []
    finally:
        # Clear the job context regardless of success or failure.
        gen_handler.set_current_job_id(None)

backend/services/gpu_info/gpu_info_impl.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,16 @@ def get_gpu_info(self) -> GpuTelemetryPayload:
7979
if self.get_mps_available():
8080
chip = self._get_macos_chip_name()
8181
name = f"{chip} (MPS)" if chip else "Apple Silicon (MPS)"
82+
vram_used = 0
83+
try:
84+
# torch.mps.current_allocated_memory() provides allocated memory in bytes
85+
vram_used = torch.mps.current_allocated_memory() // (1024 * 1024)
86+
except Exception:
87+
pass
8288
return {
8389
"name": name,
8490
"vram": self._get_system_ram_mb(),
85-
"vramUsed": 0,
91+
"vramUsed": vram_used,
8692
}
8793

8894
return {"name": "Unknown", "vram": 0, "vramUsed": 0}

0 commit comments

Comments
 (0)