
Commit 5b29260

feat(pipeline): ensure warmup uses generated frames instead of input frames (#544)
* feat(pipeline): enhance warmup handling and default workflow application
  - Added `ensure_warmup` method to the Pipeline class to manage warmup state based on resolution.
  - Integrated warmup checks in the ComfyStreamFrameProcessor to ensure the pipeline is ready before processing streams.
  - Implemented default workflow application when no prompts are provided during stream initialization.
  - Improved error handling for warmup failures and streamlined frame processing logic.

* feat(frame_processor): add loading overlay management during stream processing to avoid blocking the frame processor
  - Implemented methods to toggle and manage a loading overlay in the ComfyStreamFrameProcessor.
  - Added functionality to disable the overlay after pipeline ingest resumes.
  - Enhanced error handling for loading overlay state updates.
  - Integrated loading overlay management into the stream start process to improve user experience.

* fix(frame_processor): fix imports after merge

* bump pytrickle

* bump pytrickle dependency to version 0.1.6 in pyproject.toml

* bump pytrickle dependency to version 0.1.7 in pyproject.toml
1 parent 06b2a27 commit 5b29260
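
Taken together, the changes make stream start schedule a warmup that runs on generated frames while a loading overlay hides output until pipeline ingest resumes. A minimal sketch of that ordering, assuming `processor` and `pipeline` objects shaped like those in the diffs below:

```python
# Hypothetical sketch of the stream-start ordering introduced here;
# `processor` and `pipeline` stand in for the objects in the diffs below.
async def start_stream(processor, pipeline, width, height):
    # Show the loading overlay before warmup so viewers get feedback
    # instead of stale input frames.
    overlay_managed = processor._set_loading_overlay(True)
    try:
        # Schedules a background warmup (disabling ingest) if the
        # pipeline is not already warm at this resolution.
        await pipeline.ensure_warmup(width, height)
    except Exception:
        if overlay_managed:
            processor._set_loading_overlay(False)
        raise
    # Clear the overlay once the warmup task re-enables ingest.
    if overlay_managed:
        processor._schedule_overlay_reset_on_ingest_enabled()
```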

File tree

3 files changed: +151 -14 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ dependencies = [
 [project.optional-dependencies]
 dev = ["pytest", "pytest-cov", "ruff"]
 server = [
-    "pytrickle @ git+https://github.com/livepeer/pytrickle@v0.1.5"
+    "pytrickle @ git+https://github.com/livepeer/pytrickle@v0.1.7"
 ]
 
 [project.urls]
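
Since the pin points at a git ref rather than PyPI, the bump only takes effect after reinstalling the `server` extra (e.g. `pip install -e ".[server]"`). A quick, hedged way to confirm it, assuming the package exposes standard metadata:

```python
# Sanity check after reinstalling the server extra; assumes
# pytrickle ships standard package metadata.
from importlib.metadata import version

print(version("pytrickle"))  # expected: 0.1.7
```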

server/frame_processor.py

Lines changed: 80 additions & 13 deletions

@@ -11,7 +11,10 @@
 
 from comfystream.pipeline import Pipeline
 from comfystream.pipeline_state import PipelineState
-from comfystream.utils import convert_prompt
+from comfystream.utils import (
+    convert_prompt,
+    get_default_workflow,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -151,6 +154,49 @@ async def _forward_text_loop():
         except Exception:
             logger.warning("Failed to set up text monitoring", exc_info=True)
 
+    def _set_loading_overlay(self, enabled: bool) -> bool:
+        """Toggle the StreamProcessor loading overlay if available."""
+        processor = self._stream_processor
+        if not processor:
+            return False
+        try:
+            processor.set_loading_overlay(enabled)
+            logger.debug("Set loading overlay to %s", enabled)
+            return True
+        except Exception:
+            logger.warning("Failed to update loading overlay state", exc_info=True)
+            return False
+
+    def _schedule_overlay_reset_on_ingest_enabled(self) -> None:
+        """Disable the loading overlay after pipeline ingest resumes."""
+        if not self.pipeline:
+            self._set_loading_overlay(False)
+            return
+
+        if self.pipeline.is_ingest_enabled():
+            self._set_loading_overlay(False)
+            return
+
+        async def _wait_for_ingest_enable():
+            try:
+                while True:
+                    if self._stop_event.is_set():
+                        break
+                    if not self.pipeline:
+                        break
+                    if self.pipeline.is_ingest_enabled():
+                        break
+                    await asyncio.sleep(0.05)
+            except asyncio.CancelledError:
+                raise
+            except Exception:
+                logger.debug("Loading overlay watcher error", exc_info=True)
+            finally:
+                self._set_loading_overlay(False)
+
+        task = asyncio.create_task(_wait_for_ingest_enable())
+        self._background_tasks.append(task)
+
     async def _stop_text_forwarder(self) -> None:
         """Stop the background text forwarder task if running."""
         task = self._text_forward_task
@@ -212,16 +258,27 @@ async def on_stream_start(self, params: Optional[Dict[str, Any]] = None):
         logger.info("Stream starting")
         self._reset_stop_event()
         logger.info(f"Stream start params: {params}")
+        overlay_managed = False
 
         if not self.pipeline:
             logger.debug("Stream start requested before pipeline initialization")
             return
 
         stream_params = normalize_stream_params(params)
+        stream_width = stream_params.get("width")
+        stream_height = stream_params.get("height")
+        stream_width = int(stream_width) if stream_width is not None else None
+        stream_height = int(stream_height) if stream_height is not None else None
         prompt_payload = stream_params.pop("prompts", None)
         if prompt_payload is None:
             prompt_payload = stream_params.pop("prompt", None)
 
+        if not prompt_payload and not self.pipeline.state_manager.is_initialized():
+            logger.info(
+                "No prompts provided for new stream; applying default workflow for initialization"
+            )
+            prompt_payload = get_default_workflow()
+
         if prompt_payload:
             try:
                 await self._apply_stream_start_prompt(prompt_payload)
@@ -240,6 +297,19 @@ async def on_stream_start(self, params: Optional[Dict[str, Any]] = None):
                 logger.exception("Failed to process stream start parameters")
                 return
 
+        overlay_managed = self._set_loading_overlay(True)
+
+        try:
+            await self.pipeline.ensure_warmup(stream_width, stream_height)
+        except Exception:
+            if overlay_managed:
+                self._set_loading_overlay(False)
+            logger.exception("Failed to ensure pipeline warmup during stream start")
+            return
+
+        if overlay_managed:
+            self._schedule_overlay_reset_on_ingest_enabled()
+
         try:
             if (
                 self.pipeline.state != PipelineState.STREAMING
@@ -312,6 +382,12 @@ async def process_video_async(
         if not self.pipeline:
             return frame
 
+        # TODO: Do we really need this here?
+        await self.pipeline.ensure_warmup()
+
+        if not self.pipeline.state_manager.is_initialized():
+            return VideoProcessingResult.WITHHELD
+
         # If pipeline ingestion is paused, withhold frame so pytrickle renders the overlay
         if not self.pipeline.is_ingest_enabled():
             return VideoProcessingResult.WITHHELD
@@ -324,18 +400,9 @@ async def process_video_async(
             # Process through pipeline
             await self.pipeline.put_video_frame(av_frame)
 
-            # Try to get processed frame with short timeout
-            try:
-                processed_av_frame = await asyncio.wait_for(
-                    self.pipeline.get_processed_video_frame(),
-                    timeout=self._stream_processor.overlay_config.auto_timeout_seconds,
-                )
-                processed_frame = VideoFrame.from_av_frame_with_timing(processed_av_frame, frame)
-                return processed_frame
-
-            except asyncio.TimeoutError:
-                # No frame ready yet - return withheld sentinel to trigger overlay
-                return VideoProcessingResult.WITHHELD
+            processed_av_frame = await self.pipeline.get_processed_video_frame()
+            processed_frame = VideoFrame.from_av_frame_with_timing(processed_av_frame, frame)
+            return processed_frame
 
         except Exception as e:
             logger.error(f"Video processing failed: {e}")

src/comfystream/pipeline.py

Lines changed: 70 additions & 0 deletions

@@ -72,6 +72,10 @@ def __init__(
         self._initialize_lock = asyncio.Lock()
         self._ingest_enabled = True
         self._prompt_update_lock = asyncio.Lock()
+        self._warmup_lock = asyncio.Lock()
+        self._warmup_task: Optional[asyncio.Task] = None
+        self._warmup_completed = False
+        self._last_warmup_resolution: Optional[tuple[int, int]] = None
 
     @property
     def state(self) -> PipelineState:
@@ -155,6 +159,10 @@ async def warmup(
             await self.state_manager.transition_to(PipelineState.ERROR)
             raise
         finally:
+            if warmup_successful:
+                self._warmup_completed = True
+                self._last_warmup_resolution = (self.width, self.height)
+
             if transitioned and warmup_successful:
                 try:
                     await self.state_manager.transition_to(PipelineState.READY)
@@ -168,6 +176,63 @@ async def warmup(
                 except Exception:
                     logger.exception("Failed to restore STREAMING state after warmup")
 
+    async def ensure_warmup(self, width: Optional[int] = None, height: Optional[int] = None):
+        """Ensure the pipeline has been warmed up for the given resolution."""
+        if width and width > 0:
+            self.width = int(width)
+        if height and height > 0:
+            self.height = int(height)
+
+        if self._warmup_completed and self._last_warmup_resolution:
+            if (self.width, self.height) != self._last_warmup_resolution:
+                self._warmup_completed = False
+
+        if self._warmup_completed:
+            return
+
+        if not self.state_manager.is_initialized():
+            logger.debug("Skipping warmup scheduling - pipeline not initialized")
+            return
+
+        async with self._warmup_lock:
+            if self._warmup_completed:
+                return
+            if not self.state_manager.is_initialized():
+                return
+            if self._warmup_task and not self._warmup_task.done():
+                return
+
+            logger.info("Scheduling pipeline warmup for %sx%s", self.width, self.height)
+            self.disable_ingest()
+            self._warmup_task = asyncio.create_task(self._run_background_warmup())
+
+    async def _run_background_warmup(self):
+        try:
+            await self.warmup()
+        except asyncio.CancelledError:
+            logger.debug("Pipeline warmup task cancelled")
+            raise
+        except Exception:
+            logger.exception("Pipeline warmup failed")
+        finally:
+            self.enable_ingest()
+            self._warmup_task = None
+
+    async def _reset_warmup_state(self):
+        """Reset warmup bookkeeping and cancel any in-flight warmup tasks."""
+        async with self._warmup_lock:
+            if self._warmup_task and not self._warmup_task.done():
+                self._warmup_task.cancel()
+                try:
+                    await self._warmup_task
+                except asyncio.CancelledError:
+                    pass
+                except Exception:
+                    logger.debug("Warmup task raised during cancellation", exc_info=True)
+            self._warmup_task = None
+            self._warmup_completed = False
+            self._last_warmup_resolution = None
+
     async def _run_warmup(
         self,
         *,
@@ -266,6 +331,8 @@ async def set_prompts(
             skip_warmup: Skip automatic warmup even if auto_warmup is enabled
         """
         try:
+            await self._reset_warmup_state()
+
             prompt_list = prompts if isinstance(prompts, list) else [prompts]
             await self.client.set_prompts(prompt_list)
 
@@ -312,6 +379,8 @@ async def update_prompts(
         if was_streaming and should_warmup:
             await self.state_manager.transition_to(PipelineState.READY)
 
+        await self._reset_warmup_state()
+
         prompt_list = prompts if isinstance(prompts, list) else [prompts]
         await self.client.update_prompts(prompt_list)
 
@@ -775,6 +844,7 @@ async def cleanup(self):
         # Clear cached modalities and I/O capabilities since we're resetting
         self._cached_modalities = None
         self._cached_io_capabilities = None
+        await self._reset_warmup_state()
 
         # Clear pipeline queues
         await self._clear_pipeline_queues()
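
As the diff shows, `ensure_warmup` is effectively a resolution-keyed, idempotent scheduler: repeated calls are cheap, a resolution change invalidates the cached warmup, and prompt changes or cleanup reset the bookkeeping via `_reset_warmup_state`. A hypothetical call pattern illustrating the intended semantics:

```python
# Hypothetical call pattern for ensure_warmup; `pipeline` and
# `new_prompts` stand in for real objects.
async def demo(pipeline, new_prompts):
    await pipeline.ensure_warmup(512, 512)   # schedules a background warmup
    await pipeline.ensure_warmup(512, 512)   # returns early: a warmup task is
                                             # in flight (or done) at 512x512

    await pipeline.ensure_warmup(768, 768)   # resolution changed, so the
                                             # cached warmup is invalidated

    await pipeline.set_prompts(new_prompts)  # resets warmup bookkeeping...
    await pipeline.ensure_warmup()           # ...so this re-warms at the
                                             # current resolution
```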
