46 changes: 45 additions & 1 deletion nodes/tensor_utils/save_tensor.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch

 from comfystream import tensor_cache
@@ -21,6 +22,49 @@ def INPUT_TYPES(s):
     def IS_CHANGED(s):
         return float("nan")

+    @staticmethod
+    def _split_images(images):
+        """Yield individual images from batched tensors/lists without changing the node interface."""
+        # Torch tensor inputs with an optional batch dimension in dim 0
+        if isinstance(images, torch.Tensor):
+            if images.dim() >= 4 and images.shape[0] > 1:  # split true batches only, matching the numpy branch
+                for img in images:
+                    yield img
+                return
+            yield images
+            return
+
+        # Numpy arrays (should rarely occur, but handled for completeness)
+        if isinstance(images, np.ndarray):
+            if images.ndim >= 4 and images.shape[0] > 1:
+                for img in images:
+                    yield img
+                return
+            yield images
+            return
+
+        # Lists/tuples of already-separated images
+        if isinstance(images, (list, tuple)):
+            for img in images:
+                yield img
+            return
+
+        # Fall back to passing any other type through as-is
+        yield images
+
     def execute(self, images: torch.Tensor):
-        tensor_cache.image_outputs.put_nowait(images)
+        for img in self._split_images(images):
+            # Safely schedule the put onto the main event loop's thread
+            if tensor_cache.main_loop:
+                tensor_cache.main_loop.call_soon_threadsafe(
+                    tensor_cache.image_outputs.put_nowait, img
+                )
+            else:
+                # Fallback (mostly for tests or direct execution where init() was never called)
+                try:
+                    tensor_cache.image_outputs.put_nowait(img)
+                except RuntimeError:
+                    # In a thread with no loop this may fail or be unsafe,
+                    # but without main_loop set we have few options.
+                    pass
         return images
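
For context on the `call_soon_threadsafe` branch above: ComfyUI typically runs a node's `execute` on a worker thread, while the output queue lives on the server's asyncio loop, and asyncio queues are not thread-safe. A minimal, self-contained sketch of that hand-off pattern (names here are illustrative, not comfystream APIs):

```python
import asyncio
import threading


def worker(loop: asyncio.AbstractEventLoop, queue: asyncio.Queue) -> None:
    # Runs on a plain thread: never touch the queue directly; hand each
    # put to the loop that owns the queue instead.
    for item in ("a", "b"):
        loop.call_soon_threadsafe(queue.put_nowait, item)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    threading.Thread(target=worker, args=(asyncio.get_running_loop(), queue)).start()
    print(await queue.get(), await queue.get())  # -> a b


asyncio.run(main())
```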
76 changes: 55 additions & 21 deletions server/app.py
@@ -51,25 +51,31 @@ class VideoStreamTrack(MediaStreamTrack):

     kind = "video"

-    def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
+    def __init__(self, track: MediaStreamTrack | None, pipeline: Pipeline):
         """Initialize the VideoStreamTrack.

         Args:
-            track: The underlying media stream track.
+            track: The underlying media stream track (None if generative).
             pipeline: The processing pipeline to apply to each video frame.
         """
         super().__init__()
         self.track = track
         self.pipeline = pipeline
-        self.fps_meter = FPSMeter(metrics_manager=app["metrics_manager"], track_id=track.id)
+
+        track_id = track.id if track else self.id
+        self.fps_meter = FPSMeter(metrics_manager=app["metrics_manager"], track_id=track_id)
         self.running = True
-        self.collect_task = asyncio.create_task(self.collect_frames())

-        # Add cleanup when track ends
-        @track.on("ended")
-        async def on_ended():
-            logger.info("Source video track ended, stopping collection")
-            await cancel_collect_frames(self)
+        if track:
+            self.collect_task = asyncio.create_task(self.collect_frames())
+
+            # Add cleanup when track ends
+            @track.on("ended")
+            async def on_ended():
+                logger.info("Source video track ended, stopping collection")
+                await cancel_collect_frames(self)
+        else:
+            self.collect_task = None

     async def collect_frames(self):
         """Collect video frames from the underlying track and pass them to
@@ -153,19 +159,25 @@ async def recv(self):
 class AudioStreamTrack(MediaStreamTrack):
     kind = "audio"

-    def __init__(self, track: MediaStreamTrack, pipeline):
+    def __init__(self, track: MediaStreamTrack | None, pipeline):
         super().__init__()
         self.track = track
         self.pipeline = pipeline
         self.running = True
-        logger.info(f"AudioStreamTrack created for track {track.id}")
-        self.collect_task = asyncio.create_task(self.collect_frames())

-        # Add cleanup when track ends
-        @track.on("ended")
-        async def on_ended():
-            logger.info("Source audio track ended, stopping collection")
-            await cancel_collect_frames(self)
+        track_id = track.id if track else self.id
+        logger.info(f"AudioStreamTrack created for track {track_id}")
+
+        if track:
+            self.collect_task = asyncio.create_task(self.collect_frames())
+
+            # Add cleanup when track ends
+            @track.on("ended")
+            async def on_ended():
+                logger.info("Source audio track ended, stopping collection")
+                await cancel_collect_frames(self)
+        else:
+            self.collect_task = None

     async def collect_frames(self):
         """Collect audio frames from the underlying track and pass them to
@@ -285,7 +297,18 @@ async def offer(request):
     # Add transceivers for both audio and video if present in the offer
     if "m=video" in offer.sdp:
         logger.debug("[Offer] Adding video transceiver")
-        video_transceiver = pc.addTransceiver("video", direction="sendrecv")
+
+        track_or_kind = "video"
+        if not is_noop_mode and not pipeline.accepts_video_input() and pipeline.produces_video_output():
+            logger.info("[Offer] Creating Generative Video Track")
+            gen_track = VideoStreamTrack(None, pipeline)
+            tracks["video"] = gen_track
+            track_or_kind = gen_track
+
+            # Store video track in app for stats
+            request.app["video_tracks"][gen_track.id] = gen_track
+
+        video_transceiver = pc.addTransceiver(track_or_kind, direction="sendrecv")
         caps = RTCRtpSender.getCapabilities("video")
         prefs = list(filter(lambda x: x.name == "H264", caps.codecs))
         video_transceiver.setCodecPreferences(prefs)
@@ -296,7 +319,15 @@

     if "m=audio" in offer.sdp:
         logger.debug("[Offer] Adding audio transceiver")
-        audio_transceiver = pc.addTransceiver("audio", direction="sendrecv")
+
+        track_or_kind = "audio"
+        if not is_noop_mode and not pipeline.accepts_audio_input() and pipeline.produces_audio_output():
+            logger.info("[Offer] Creating Generative Audio Track")
+            gen_track = AudioStreamTrack(None, pipeline)
+            tracks["audio"] = gen_track
+            track_or_kind = gen_track
+
+        audio_transceiver = pc.addTransceiver(track_or_kind, direction="sendrecv")
         audio_caps = RTCRtpSender.getCapabilities("audio")
         # Prefer Opus for audio
         audio_prefs = [codec for codec in audio_caps.codecs if codec.name == "opus"]
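
For context on the `track_or_kind` change in both branches: aiortc's `addTransceiver` accepts either a kind string (`"video"`/`"audio"`) or a `MediaStreamTrack` instance, so handing it a source-less track makes the server the sender for that m-line. A minimal sketch of a generative track under that API (the `ColorBarTrack` name, frame size, and 30 fps cadence are illustrative assumptions):

```python
import asyncio
from fractions import Fraction

import numpy as np
from aiortc import MediaStreamTrack, RTCPeerConnection
from av import VideoFrame


class ColorBarTrack(MediaStreamTrack):
    """Generative track: synthesizes frames instead of wrapping a source."""

    kind = "video"

    def __init__(self) -> None:
        super().__init__()
        self._pts = 0

    async def recv(self) -> VideoFrame:
        await asyncio.sleep(1 / 30)  # pace output at roughly 30 fps
        rgb = np.zeros((480, 640, 3), dtype=np.uint8)
        rgb[:, :, self._pts % 3] = 255  # cycle R/G/B so motion is visible
        frame = VideoFrame.from_ndarray(rgb, format="rgb24")
        frame.pts = self._pts
        frame.time_base = Fraction(1, 30)
        self._pts += 1
        return frame


pc = RTCPeerConnection()
# Passing a track instance (not the string "video") makes this transceiver
# send the synthesized media while staying ready to receive.
pc.addTransceiver(ColorBarTrack(), direction="sendrecv")
```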
@@ -654,6 +685,9 @@ async def on_startup(app: web.Application):
     if app["media_ports"]:
         patch_loop_datagram(app["media_ports"])

+    from comfystream import tensor_cache
+    tensor_cache.init(asyncio.get_running_loop())
+
     app["pipeline"] = Pipeline(
         width=512,
         height=512,
2 changes: 1 addition & 1 deletion src/comfystream/modalities.py
@@ -25,7 +25,7 @@ class WorkflowModality(TypedDict):
"audio_input": {"LoadAudioTensor"},
"audio_output": {"SaveAudioTensor"},
# Text nodes
"text_input": set(), # No text input nodes currently
"text_input": {"PrimitiveString"}, # Basic text input node
"text_output": {"SaveTextTensor"},
}

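For illustration, this is how such node sets are typically consumed: scan a ComfyUI API-format workflow (a mapping of node id to `{"class_type": ...}`) and report which modalities it touches. The `MODALITY_NODES` stand-in and the helper below are hypothetical, not comfystream's API:

```python
from typing import Dict, Set

# Hypothetical stand-in for the registry this diff edits.
MODALITY_NODES: Dict[str, Set[str]] = {
    "text_input": {"PrimitiveString"},
    "text_output": {"SaveTextTensor"},
}


def detect_modalities(workflow: Dict[str, dict]) -> Set[str]:
    """Return the modality keys whose node sets intersect the workflow."""
    class_types = {node.get("class_type") for node in workflow.values()}
    return {name for name, nodes in MODALITY_NODES.items() if class_types & nodes}


# A workflow containing a PrimitiveString node now counts as text input.
assert "text_input" in detect_modalities({"1": {"class_type": "PrimitiveString"}})
```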
14 changes: 14 additions & 0 deletions src/comfystream/pipeline.py
@@ -3,6 +3,7 @@
 from typing import Any, Dict, List, Optional, Set, Union

 import av
+from fractions import Fraction
 import numpy as np
 import torch

@@ -76,6 +77,7 @@ def __init__(
         self._warmup_task: Optional[asyncio.Task] = None
         self._warmup_completed = False
         self._last_warmup_resolution: Optional[tuple[int, int]] = None
+        self._generated_pts = 0

     @property
     def state(self) -> PipelineState:
@@ -682,6 +684,18 @@ async def get_processed_video_frame(self) -> av.VideoFrame:
         Returns:
             The processed video frame, or original frame if no processing needed
         """
+        # Handle generative video case (no input, but produces output)
+        if not self.accepts_video_input() and self.produces_video_output():
+            async with temporary_log_level("comfy", self._comfyui_inference_log_level):
+                out_tensor = await self.client.get_video_output()
+
+            processed_frame = self.video_postprocess(out_tensor)
+            processed_frame.pts = self._generated_pts
+            processed_frame.time_base = Fraction(1, 30)
+            self._generated_pts += 1
+
+            return processed_frame
+
         frame = await self.video_incoming_frames.get()

         # Skip frames that were marked as skipped
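On the timestamping in `get_processed_video_frame`: with `time_base = Fraction(1, 30)`, a frame's presentation time is `pts * time_base`, so the incrementing counter produces a fixed 30 fps cadence. A small sketch of the same stamping applied to synthesized frames (the 512×512 size mirrors the pipeline defaults; the loop itself is illustrative):

```python
from fractions import Fraction

import av
import numpy as np

time_base = Fraction(1, 30)  # one pts tick = 1/30 s
for n in range(3):
    frame = av.VideoFrame.from_ndarray(
        np.zeros((512, 512, 3), dtype=np.uint8), format="rgb24"
    )
    frame.pts = n
    frame.time_base = time_base
    print(float(frame.pts * frame.time_base))  # 0.0, 0.0333..., 0.0667...
```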
6 changes: 6 additions & 0 deletions src/comfystream/tensor_cache.py
@@ -13,3 +13,9 @@
 audio_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue()

 text_outputs: AsyncQueue[str] = AsyncQueue()
+
+main_loop = None
+
+def init(loop):
+    global main_loop
+    main_loop = loop
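
A possible follow-up for this module (a sketch, not part of this PR): type the handle and centralize the fallback, so callers like `SaveTensor` collapse their branching into a single call:

```python
import asyncio
from typing import Any, Optional

main_loop: Optional[asyncio.AbstractEventLoop] = None


def init(loop: asyncio.AbstractEventLoop) -> None:
    global main_loop
    main_loop = loop


def put_threadsafe(queue: "asyncio.Queue[Any]", item: Any) -> None:
    """Enqueue from any thread, via the registered loop when one is set."""
    if main_loop is not None:
        main_loop.call_soon_threadsafe(queue.put_nowait, item)
    else:
        queue.put_nowait(item)  # best effort when init() was never called
```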
46 changes: 46 additions & 0 deletions test/test_save_tensor.py
@@ -0,0 +1,46 @@
import asyncio

import pytest
import torch

from comfystream import tensor_cache
from nodes.tensor_utils.save_tensor import SaveTensor


async def _drain_image_queue():
while True:
try:
tensor_cache.image_outputs.get_nowait()
except asyncio.QueueEmpty:
break


@pytest.mark.asyncio
async def test_save_tensor_splits_batched_images():
await _drain_image_queue()

images = torch.rand(2, 8, 8, 3)

SaveTensor().execute(images)

first = await tensor_cache.image_outputs.get()
second = await tensor_cache.image_outputs.get()

assert torch.equal(first, images[0])
assert torch.equal(second, images[1])
assert tensor_cache.image_outputs.empty()


@pytest.mark.asyncio
async def test_save_tensor_passthrough_single_image():
await _drain_image_queue()

images = torch.rand(1, 4, 4, 3)

SaveTensor().execute(images)

queued = await tensor_cache.image_outputs.get()

assert torch.equal(queued, images)
assert tensor_cache.image_outputs.empty()
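
Both tests above exercise only the fallback path, since `tensor_cache.main_loop` is unset under pytest. A sketch of an additional case covering the `call_soon_threadsafe` path (`asyncio.to_thread` stands in for ComfyUI's worker thread; not part of this PR):

```python
@pytest.mark.asyncio
async def test_save_tensor_uses_main_loop_from_worker_thread():
    await _drain_image_queue()
    tensor_cache.init(asyncio.get_running_loop())
    try:
        images = torch.rand(2, 4, 4, 3)
        # execute() runs off-loop, so puts must arrive via call_soon_threadsafe.
        await asyncio.to_thread(SaveTensor().execute, images)
        first = await tensor_cache.image_outputs.get()
        second = await tensor_cache.image_outputs.get()
        assert torch.equal(first, images[0])
        assert torch.equal(second, images[1])
    finally:
        tensor_cache.main_loop = None  # avoid leaking loop state into other tests
```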