diff --git a/nodes/tensor_utils/save_tensor.py b/nodes/tensor_utils/save_tensor.py
index 3a021aa5..1c39e629 100644
--- a/nodes/tensor_utils/save_tensor.py
+++ b/nodes/tensor_utils/save_tensor.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 
 from comfystream import tensor_cache
@@ -21,6 +22,49 @@ def INPUT_TYPES(s):
     def IS_CHANGED(s):
         return float("nan")
 
+    @staticmethod
+    def _split_images(images):
+        """Yield individual images from batched tensors/lists without changing the node interface."""
+        # Torch tensors with an optional batch dim 0; split only real batches (size > 1)
+        if isinstance(images, torch.Tensor):
+            if images.dim() >= 4 and images.shape[0] > 1:
+                for img in images:
+                    yield img
+                return
+            yield images
+            return
+
+        # Numpy arrays (should rarely occur, but handled for completeness)
+        if isinstance(images, np.ndarray):
+            if images.ndim >= 4 and images.shape[0] > 1:
+                for img in images:
+                    yield img
+                return
+            yield images
+            return
+
+        # Lists/tuples of images that are already separated
+        if isinstance(images, (list, tuple)):
+            for img in images:
+                yield img
+            return
+
+        # Fall back to passing any other type through as-is
+        yield images
+
     def execute(self, images: torch.Tensor):
-        tensor_cache.image_outputs.put_nowait(images)
+        for img in self._split_images(images):
+            # Safely schedule the put operation onto the main event loop thread
+            if tensor_cache.main_loop:
+                tensor_cache.main_loop.call_soon_threadsafe(
+                    tensor_cache.image_outputs.put_nowait, img
+                )
+            else:
+                # Fallback (mostly for tests or direct execution without init)
+                try:
+                    tensor_cache.image_outputs.put_nowait(img)
+                except RuntimeError:
+                    # In a thread with no loop this may fail or be unsafe,
+                    # but without main_loop there are few options.
+                    pass
         return images
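Because node execution can run off the event loop thread (which is what the patch's comments guard against), the only safe way to touch a queue owned by the server's loop is `call_soon_threadsafe`; that is what the `tensor_cache.main_loop` indirection buys. A minimal, self-contained sketch of the same handoff pattern, with hypothetical names, not part of the patch:

    import asyncio
    import threading

    queue: asyncio.Queue = asyncio.Queue()
    main_loop = None  # captured once at startup, like tensor_cache.main_loop

    def worker():
        # Runs on a non-loop thread: hand the item to the loop thread
        # instead of calling queue.put_nowait() here directly.
        assert main_loop is not None
        main_loop.call_soon_threadsafe(queue.put_nowait, "frame")

    async def main():
        global main_loop
        main_loop = asyncio.get_running_loop()
        threading.Thread(target=worker).start()
        print(await queue.get())  # -> frame

    asyncio.run(main())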
""" super().__init__() self.track = track self.pipeline = pipeline - self.fps_meter = FPSMeter(metrics_manager=app["metrics_manager"], track_id=track.id) + + track_id = track.id if track else self.id + self.fps_meter = FPSMeter(metrics_manager=app["metrics_manager"], track_id=track_id) self.running = True - self.collect_task = asyncio.create_task(self.collect_frames()) - - # Add cleanup when track ends - @track.on("ended") - async def on_ended(): - logger.info("Source video track ended, stopping collection") - await cancel_collect_frames(self) + + if track: + self.collect_task = asyncio.create_task(self.collect_frames()) + + # Add cleanup when track ends + @track.on("ended") + async def on_ended(): + logger.info("Source video track ended, stopping collection") + await cancel_collect_frames(self) + else: + self.collect_task = None async def collect_frames(self): """Collect video frames from the underlying track and pass them to @@ -153,19 +159,25 @@ async def recv(self): class AudioStreamTrack(MediaStreamTrack): kind = "audio" - def __init__(self, track: MediaStreamTrack, pipeline): + def __init__(self, track: MediaStreamTrack | None, pipeline): super().__init__() self.track = track self.pipeline = pipeline self.running = True - logger.info(f"AudioStreamTrack created for track {track.id}") - self.collect_task = asyncio.create_task(self.collect_frames()) - - # Add cleanup when track ends - @track.on("ended") - async def on_ended(): - logger.info("Source audio track ended, stopping collection") - await cancel_collect_frames(self) + + track_id = track.id if track else self.id + logger.info(f"AudioStreamTrack created for track {track_id}") + + if track: + self.collect_task = asyncio.create_task(self.collect_frames()) + + # Add cleanup when track ends + @track.on("ended") + async def on_ended(): + logger.info("Source audio track ended, stopping collection") + await cancel_collect_frames(self) + else: + self.collect_task = None async def collect_frames(self): """Collect audio frames from the underlying track and pass them to @@ -285,7 +297,18 @@ async def offer(request): # Add transceivers for both audio and video if present in the offer if "m=video" in offer.sdp: logger.debug("[Offer] Adding video transceiver") - video_transceiver = pc.addTransceiver("video", direction="sendrecv") + + track_or_kind = "video" + if not is_noop_mode and not pipeline.accepts_video_input() and pipeline.produces_video_output(): + logger.info("[Offer] Creating Generative Video Track") + gen_track = VideoStreamTrack(None, pipeline) + tracks["video"] = gen_track + track_or_kind = gen_track + + # Store video track in app for stats + request.app["video_tracks"][gen_track.id] = gen_track + + video_transceiver = pc.addTransceiver(track_or_kind, direction="sendrecv") caps = RTCRtpSender.getCapabilities("video") prefs = list(filter(lambda x: x.name == "H264", caps.codecs)) video_transceiver.setCodecPreferences(prefs) @@ -296,7 +319,15 @@ async def offer(request): if "m=audio" in offer.sdp: logger.debug("[Offer] Adding audio transceiver") - audio_transceiver = pc.addTransceiver("audio", direction="sendrecv") + + track_or_kind = "audio" + if not is_noop_mode and not pipeline.accepts_audio_input() and pipeline.produces_audio_output(): + logger.info("[Offer] Creating Generative Audio Track") + gen_track = AudioStreamTrack(None, pipeline) + tracks["audio"] = gen_track + track_or_kind = gen_track + + audio_transceiver = pc.addTransceiver(track_or_kind, direction="sendrecv") audio_caps = RTCRtpSender.getCapabilities("audio") # Prefer 
diff --git a/src/comfystream/modalities.py b/src/comfystream/modalities.py
index ded16829..c48668b0 100644
--- a/src/comfystream/modalities.py
+++ b/src/comfystream/modalities.py
@@ -25,7 +25,7 @@ class WorkflowModality(TypedDict):
     "audio_input": {"LoadAudioTensor"},
     "audio_output": {"SaveAudioTensor"},
     # Text nodes
-    "text_input": set(),  # No text input nodes currently
+    "text_input": {"PrimitiveString"},  # Basic text input node
     "text_output": {"SaveTextTensor"},
 }
 
diff --git a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py
index 9143e591..2c234c50 100644
--- a/src/comfystream/pipeline.py
+++ b/src/comfystream/pipeline.py
@@ -3,6 +3,7 @@
 from typing import Any, Dict, List, Optional, Set, Union
 
 import av
+from fractions import Fraction
 import numpy as np
 import torch
 
@@ -76,6 +77,7 @@ def __init__(
         self._warmup_task: Optional[asyncio.Task] = None
         self._warmup_completed = False
         self._last_warmup_resolution: Optional[tuple[int, int]] = None
+        self._generated_pts = 0
 
     @property
     def state(self) -> PipelineState:
@@ -682,6 +684,18 @@ async def get_processed_video_frame(self) -> av.VideoFrame:
         Returns:
             The processed video frame, or original frame if no processing needed
         """
+        # Handle generative video case (no input, but produces output)
+        if not self.accepts_video_input() and self.produces_video_output():
+            async with temporary_log_level("comfy", self._comfyui_inference_log_level):
+                out_tensor = await self.client.get_video_output()
+
+            processed_frame = self.video_postprocess(out_tensor)
+            processed_frame.pts = self._generated_pts
+            processed_frame.time_base = Fraction(1, 30)
+            self._generated_pts += 1
+
+            return processed_frame
+
         frame = await self.video_incoming_frames.get()
 
         # Skip frames that were marked as skipped
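With `time_base = Fraction(1, 30)` and a pts counter advancing by one per frame, a generated frame's presentation time is pts × time_base = n/30 seconds, i.e. a fixed 30 fps cadence (the rate is hardcoded in the patch). A quick standalone check of that arithmetic:

    from fractions import Fraction

    time_base = Fraction(1, 30)
    for pts in range(3):
        # Presentation time in seconds = pts * time_base
        print(float(pts * time_base))  # 0.0, 0.0333..., 0.0666...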
diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py
index 609f98b8..d8d5c7ed 100644
--- a/src/comfystream/tensor_cache.py
+++ b/src/comfystream/tensor_cache.py
@@ -13,3 +13,9 @@
 audio_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue()
 
 text_outputs: AsyncQueue[str] = AsyncQueue()
+
+main_loop = None
+
+def init(loop):
+    global main_loop
+    main_loop = loop
diff --git a/test/test_save_tensor.py b/test/test_save_tensor.py
new file mode 100644
index 00000000..df548adb
--- /dev/null
+++ b/test/test_save_tensor.py
@@ -0,0 +1,46 @@
+import asyncio
+
+import pytest
+import torch
+
+from comfystream import tensor_cache
+from nodes.tensor_utils.save_tensor import SaveTensor
+
+
+async def _drain_image_queue():
+    while True:
+        try:
+            tensor_cache.image_outputs.get_nowait()
+        except asyncio.QueueEmpty:
+            break
+
+
+@pytest.mark.asyncio
+async def test_save_tensor_splits_batched_images():
+    await _drain_image_queue()
+
+    images = torch.rand(2, 8, 8, 3)
+
+    SaveTensor().execute(images)
+
+    first = await tensor_cache.image_outputs.get()
+    second = await tensor_cache.image_outputs.get()
+
+    assert torch.equal(first, images[0])
+    assert torch.equal(second, images[1])
+    assert tensor_cache.image_outputs.empty()
+
+
+@pytest.mark.asyncio
+async def test_save_tensor_passthrough_single_image():
+    await _drain_image_queue()
+
+    images = torch.rand(1, 4, 4, 3)
+
+    SaveTensor().execute(images)
+
+    queued = await tensor_cache.image_outputs.get()
+
+    assert torch.equal(queued, images)
+    assert tensor_cache.image_outputs.empty()
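A caveat when these tests run in the same session as server code: `tensor_cache.main_loop` is module-level state, so a loop captured earlier (e.g. by `on_startup` or a previous test's event loop) could route `execute()` through `call_soon_threadsafe` on a stale loop instead of the direct fallback these tests rely on. A hypothetical autouse fixture (not part of the patch) that pins the tests to the fallback path:

    import pytest
    from comfystream import tensor_cache

    @pytest.fixture(autouse=True)
    def _reset_main_loop():
        # Hypothetical isolation fixture: force the direct put_nowait
        # fallback in SaveTensor.execute() for each test.
        tensor_cache.main_loop = None
        yield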