
Commit b1fa192

convert nodes_stable3d.py to V3 schema (#10204)

1 parent 2ed74f7


comfy_extras/nodes_stable3d.py: 89 additions, 66 deletions
@@ -1,6 +1,8 @@
 import torch
 import nodes
 import comfy.utils
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io

 def camera_embeddings(elevation, azimuth):
     elevation = torch.as_tensor([elevation])
@@ -20,26 +22,31 @@ def camera_embeddings(elevation, azimuth):
     return embeddings


-class StableZero123_Conditioning:
+class StableZero123_Conditioning(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "clip_vision": ("CLIP_VISION",),
-                              "init_image": ("IMAGE",),
-                              "vae": ("VAE",),
-                              "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                              "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
-                              "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
-                             }}
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/3d_models"
-
-    def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="StableZero123_Conditioning",
+            category="conditioning/3d_models",
+            inputs=[
+                io.ClipVision.Input("clip_vision"),
+                io.Image.Input("init_image"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
+                io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent")
+            ]
+        )
+
+    @classmethod
+    def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth) -> io.NodeOutput:
         output = clip_vision.encode_image(init_image)
         pooled = output.image_embeds.unsqueeze(0)
         pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -51,30 +58,35 @@ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevat
         positive = [[cond, {"concat_latent_image": t}]]
         negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
         latent = torch.zeros([batch_size, 4, height // 8, width // 8])
-        return (positive, negative, {"samples":latent})
+        return io.NodeOutput(positive, negative, {"samples":latent})
+
+class StableZero123_Conditioning_Batched(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="StableZero123_Conditioning_Batched",
+            category="conditioning/3d_models",
+            inputs=[
+                io.ClipVision.Input("clip_vision"),
+                io.Image.Input("init_image"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
+                io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
+                io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
+                io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent")
+            ]
+        )

-class StableZero123_Conditioning_Batched:
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "clip_vision": ("CLIP_VISION",),
-                              "init_image": ("IMAGE",),
-                              "vae": ("VAE",),
-                              "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                              "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
-                              "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
-                              "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
-                              "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
-                             }}
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/3d_models"
-
-    def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
+    def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment) -> io.NodeOutput:
         output = clip_vision.encode_image(init_image)
         pooled = output.image_embeds.unsqueeze(0)
         pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -93,27 +105,32 @@ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevat
         positive = [[cond, {"concat_latent_image": t}]]
         negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
         latent = torch.zeros([batch_size, 4, height // 8, width // 8])
-        return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
+        return io.NodeOutput(positive, negative, {"samples":latent, "batch_index": [0] * batch_size})

-class SV3D_Conditioning:
+class SV3D_Conditioning(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "clip_vision": ("CLIP_VISION",),
-                              "init_image": ("IMAGE",),
-                              "vae": ("VAE",),
-                              "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
-                              "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
-                             }}
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/3d_models"
-
-    def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SV3D_Conditioning",
+            category="conditioning/3d_models",
+            inputs=[
+                io.ClipVision.Input("clip_vision"),
+                io.Image.Input("init_image"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("video_frames", default=21, min=1, max=4096),
+                io.Float.Input("elevation", default=0.0, min=-90.0, max=90.0, step=0.1, round=False)
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent")
+            ]
+        )
+
+    @classmethod
+    def execute(cls, clip_vision, init_image, vae, width, height, video_frames, elevation) -> io.NodeOutput:
         output = clip_vision.encode_image(init_image)
         pooled = output.image_embeds.unsqueeze(0)
         pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -133,11 +150,17 @@ def encode(self, clip_vision, init_image, vae, width, height, video_frames, elev
         positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
         negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
         latent = torch.zeros([video_frames, 4, height // 8, width // 8])
-        return (positive, negative, {"samples":latent})
+        return io.NodeOutput(positive, negative, {"samples":latent})
+

+class Stable3DExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            StableZero123_Conditioning,
+            StableZero123_Conditioning_Batched,
+            SV3D_Conditioning,
+        ]

-NODE_CLASS_MAPPINGS = {
-    "StableZero123_Conditioning": StableZero123_Conditioning,
-    "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
-    "SV3D_Conditioning": SV3D_Conditioning,
-}
+async def comfy_entrypoint() -> Stable3DExtension:
+    return Stable3DExtension()
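For reference, the conversion pattern this commit applies is the same for every node: INPUT_TYPES becomes a define_schema() classmethod returning an io.Schema; RETURN_TYPES and RETURN_NAMES become typed outputs with display names; FUNCTION plus the encode() method become a classmethod execute() returning io.NodeOutput; and the module-level NODE_CLASS_MAPPINGS dict is replaced by a ComfyExtension subclass exposed through an async comfy_entrypoint(). Below is a minimal sketch of that pattern applied to a standalone example. The node id EmptyLatentSketch, its "latent" category, and its parameter ranges are invented for illustration; only API constructs that appear in this diff are used.

import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class EmptyLatentSketch(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="EmptyLatentSketch",  # hypothetical id, for illustration only
            category="latent",
            inputs=[
                io.Int.Input("width", default=512, min=16, max=8192, step=8),
                io.Int.Input("height", default=512, min=16, max=8192, step=8),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, width, height, batch_size) -> io.NodeOutput:
        # V3 execute() is a classmethod and returns io.NodeOutput
        # instead of the bare tuple returned by a V1 encode().
        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
        return io.NodeOutput({"samples": latent})


class SketchExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # Replaces the module-level NODE_CLASS_MAPPINGS dict.
        return [EmptyLatentSketch]


async def comfy_entrypoint() -> SketchExtension:
    # ComfyUI discovers V3 extensions through this async entrypoint.
    return SketchExtension()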
