Skip to content

Commit f9d2e4b

Browse files
authored
convert WanCameraEmbedding node to V3 schema (#9714)
1 parent 45bc1f5 commit f9d2e4b

File tree

1 file changed

+52
-31
lines changed

1 file changed

+52
-31
lines changed

comfy_extras/nodes_camera_trajectory.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import torch
33
import numpy as np
44
from einops import rearrange
5+
from typing_extensions import override
56
import comfy.model_management
67

8+
from comfy_api.latest import ComfyExtension, io
79

810

9-
MAX_RESOLUTION = nodes.MAX_RESOLUTION
10-
1111
CAMERA_DICT = {
1212
"base_T_norm": 1.5,
1313
"base_angle": np.pi/3,
@@ -148,32 +148,47 @@ def compute_R_form_rad_angle(angles):
148148
RT = np.stack(RT)
149149
return RT
150150

151-
class WanCameraEmbedding:
151+
class WanCameraEmbedding(io.ComfyNode):
152+
@classmethod
153+
def define_schema(cls):
154+
return io.Schema(
155+
node_id="WanCameraEmbedding",
156+
category="camera",
157+
inputs=[
158+
io.Combo.Input(
159+
"camera_pose",
160+
options=[
161+
"Static",
162+
"Pan Up",
163+
"Pan Down",
164+
"Pan Left",
165+
"Pan Right",
166+
"Zoom In",
167+
"Zoom Out",
168+
"Anti Clockwise (ACW)",
169+
"ClockWise (CW)",
170+
],
171+
default="Static",
172+
),
173+
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
174+
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
175+
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
176+
io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
177+
io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
178+
io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
179+
io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
180+
io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
181+
],
182+
outputs=[
183+
io.WanCameraEmbedding.Output(display_name="camera_embedding"),
184+
io.Int.Output(display_name="width"),
185+
io.Int.Output(display_name="height"),
186+
io.Int.Output(display_name="length"),
187+
],
188+
)
189+
152190
@classmethod
153-
def INPUT_TYPES(cls):
154-
return {
155-
"required": {
156-
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
157-
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
158-
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
159-
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
160-
},
161-
"optional":{
162-
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
163-
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
164-
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
165-
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
166-
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
167-
}
168-
169-
}
170-
171-
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
172-
RETURN_NAMES = ("camera_embedding","width","height","length")
173-
FUNCTION = "run"
174-
CATEGORY = "camera"
175-
176-
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
191+
def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
177192
"""
178193
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021)
179194
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
@@ -210,9 +225,15 @@ def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx
210225
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
211226
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
212227

213-
return (control_camera_video, width, height, length)
228+
return io.NodeOutput(control_camera_video, width, height, length)
214229

215230

216-
NODE_CLASS_MAPPINGS = {
217-
"WanCameraEmbedding": WanCameraEmbedding,
218-
}
231+
class CameraTrajectoryExtension(ComfyExtension):
232+
@override
233+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
234+
return [
235+
WanCameraEmbedding,
236+
]
237+
238+
async def comfy_entrypoint() -> CameraTrajectoryExtension:
239+
return CameraTrajectoryExtension()

0 commit comments

Comments
 (0)