|
1 | 1 | import os |
2 | 2 | import json |
3 | 3 | import torch |
4 | | -import asyncio |
5 | 4 | import numpy as np |
6 | 5 | from PIL import Image |
7 | 6 | from typing import Union |
8 | 7 | from pydantic import BaseModel, field_validator |
9 | 8 | import pathlib |
10 | 9 |
|
11 | 10 | from .interface import Pipeline |
12 | | -from comfystream.client import ComfyStreamClient |
| 11 | +from comfystream.pipeline import Pipeline as ComfyStreamPipeline |
13 | 12 | from trickle import VideoFrame, VideoOutput |
14 | 13 |
|
15 | 14 | import logging |
@@ -52,52 +51,43 @@ def validate_prompt(cls, v) -> dict: |
class ComfyUI(Pipeline):
    """Video pipeline that delegates frame processing to a ComfyStream pipeline.

    Frames are handed to an underlying ``ComfyStreamPipeline`` (created from the
    workspace named by the ``COMFY_UI_WORKSPACE_ENV`` environment variable) and
    read back as processed tensors, which are converted to PIL images for the
    ``VideoOutput`` consumers.
    """

    def __init__(self):
        # Workspace directory for the embedded ComfyUI instance; may be None if
        # the env var is unset (the underlying pipeline decides how to handle that).
        comfy_ui_workspace = os.getenv(COMFY_UI_WORKSPACE_ENV)
        # Fixed 512x512 processing resolution — TODO(review): confirm inputs of
        # other sizes are resized/handled by comfystream rather than rejected.
        self.pipeline = ComfyStreamPipeline(width=512, height=512, cwd=comfy_ui_workspace)
        # Annotation only — self.params is not assigned until initialize() runs;
        # reading it before then raises AttributeError.
        self.params: ComfyUIParams

    async def initialize(self, **params):
        """Validate params, install the prompt, and warm up the pipeline.

        Raises pydantic's ValidationError if ``params`` do not satisfy
        ComfyUIParams.
        """
        new_params = ComfyUIParams(**params)
        logging.info(f"Initializing ComfyUI Pipeline with prompt: {new_params.prompt}")
        # A single prompt is wrapped in a list — the comfystream API takes a list.
        await self.pipeline.set_prompts([new_params.prompt])
        self.params = new_params

        # Warm up the pipeline
        await self.pipeline.warm_video()
        logging.info("Pipeline initialization and warmup complete")

    async def put_video_frame(self, frame: VideoFrame, request_id: str):
        """Convert an incoming frame to a tensor and enqueue it for processing."""
        # Convert VideoFrame to format expected by comfystream:
        # RGB image -> float32 tensor in [0, 1] with a leading batch dim (1, H, W, 3).
        image_np = np.array(frame.image.convert("RGB")).astype(np.float32) / 255.0
        frame.side_data.input = torch.tensor(image_np).unsqueeze(0)
        # NOTE(review): skipped is set True and never reset or filtered in this
        # class — presumably comfystream clears it on processed frames; confirm.
        frame.side_data.skipped = True
        # Carry the request id through the pipeline so the output can be matched
        # back to its originating request.
        frame.side_data.request_id = request_id
        await self.pipeline.put_video_frame(frame)

    async def get_processed_video_frame(self) -> VideoOutput:
        """Dequeue one processed frame and return it as a VideoOutput image."""
        processed_frame = await self.pipeline.get_processed_video_frame()
        # Convert back to VideoOutput format.
        # NOTE(review): the result is read from side_data.input — assumes the
        # comfystream pipeline writes its output tensor back into that field;
        # verify against comfystream's implementation.
        result_tensor = processed_frame.side_data.input
        # Drop the batch dim, rescale [0, 1] floats to uint8 for PIL.
        result_tensor = result_tensor.squeeze(0)
        result_image_np = (result_tensor * 255).byte()
        result_image = Image.fromarray(result_image_np.cpu().numpy())
        return VideoOutput(processed_frame, processed_frame.side_data.request_id).replace_image(result_image)

    async def update_params(self, **params):
        """Re-validate params and swap the active prompt on the live pipeline."""
        new_params = ComfyUIParams(**params)
        logging.info(f"Updating ComfyUI Pipeline Prompt: {new_params.prompt}")
        await self.pipeline.update_prompts([new_params.prompt])
        self.params = new_params

    async def stop(self):
        """Shut down the underlying ComfyStream pipeline and release its resources."""
        logging.info("Stopping ComfyUI pipeline")
        await self.pipeline.cleanup()
        logging.info("ComfyUI pipeline stopped")
0 commit comments