Skip to content

Commit 3e68bc3

Browse files
authored
convert nodes_torch_compile.py to V3 schema (#10173)
1 parent c2c5a7d commit 3e68bc3

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed
Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,39 @@
1+
from typing_extensions import override
2+
from comfy_api.latest import ComfyExtension, io
13
from comfy_api.torch_helpers import set_torch_compile_wrapper
24

35

4-
class TorchCompileModel:
6+
class TorchCompileModel(io.ComfyNode):
57
@classmethod
6-
def INPUT_TYPES(s):
7-
return {"required": { "model": ("MODEL",),
8-
"backend": (["inductor", "cudagraphs"],),
9-
}}
10-
RETURN_TYPES = ("MODEL",)
11-
FUNCTION = "patch"
8+
def define_schema(cls) -> io.Schema:
9+
return io.Schema(
10+
node_id="TorchCompileModel",
11+
category="_for_testing",
12+
inputs=[
13+
io.Model.Input("model"),
14+
io.Combo.Input(
15+
"backend",
16+
options=["inductor", "cudagraphs"],
17+
),
18+
],
19+
outputs=[io.Model.Output()],
20+
is_experimental=True,
21+
)
1222

13-
CATEGORY = "_for_testing"
14-
EXPERIMENTAL = True
15-
16-
def patch(self, model, backend):
23+
@classmethod
24+
def execute(cls, model, backend) -> io.NodeOutput:
1725
m = model.clone()
1826
set_torch_compile_wrapper(model=m, backend=backend)
19-
return (m, )
27+
return io.NodeOutput(m)
28+
29+
30+
class TorchCompileExtension(ComfyExtension):
31+
@override
32+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
33+
return [
34+
TorchCompileModel,
35+
]
36+
2037

21-
NODE_CLASS_MAPPINGS = {
22-
"TorchCompileModel": TorchCompileModel,
23-
}
38+
async def comfy_entrypoint() -> TorchCompileExtension:
39+
return TorchCompileExtension()

0 commit comments

Comments
 (0)