|
| 1 | +from typing_extensions import override |
| 2 | +from comfy_api.latest import ComfyExtension, io |
1 | 3 | from comfy_api.torch_helpers import set_torch_compile_wrapper
|
2 | 4 |
|
3 | 5 |
|
4 |
| -class TorchCompileModel: |
| 6 | +class TorchCompileModel(io.ComfyNode): |
5 | 7 | @classmethod
|
6 |
| - def INPUT_TYPES(s): |
7 |
| - return {"required": { "model": ("MODEL",), |
8 |
| - "backend": (["inductor", "cudagraphs"],), |
9 |
| - }} |
10 |
| - RETURN_TYPES = ("MODEL",) |
11 |
| - FUNCTION = "patch" |
| 8 | + def define_schema(cls) -> io.Schema: |
| 9 | + return io.Schema( |
| 10 | + node_id="TorchCompileModel", |
| 11 | + category="_for_testing", |
| 12 | + inputs=[ |
| 13 | + io.Model.Input("model"), |
| 14 | + io.Combo.Input( |
| 15 | + "backend", |
| 16 | + options=["inductor", "cudagraphs"], |
| 17 | + ), |
| 18 | + ], |
| 19 | + outputs=[io.Model.Output()], |
| 20 | + is_experimental=True, |
| 21 | + ) |
12 | 22 |
|
13 |
| - CATEGORY = "_for_testing" |
14 |
| - EXPERIMENTAL = True |
15 |
| - |
16 |
| - def patch(self, model, backend): |
| 23 | + @classmethod |
| 24 | + def execute(cls, model, backend) -> io.NodeOutput: |
17 | 25 | m = model.clone()
|
18 | 26 | set_torch_compile_wrapper(model=m, backend=backend)
|
19 |
| - return (m, ) |
| 27 | + return io.NodeOutput(m) |
| 28 | + |
| 29 | + |
| 30 | +class TorchCompileExtension(ComfyExtension): |
| 31 | + @override |
| 32 | + async def get_node_list(self) -> list[type[io.ComfyNode]]: |
| 33 | + return [ |
| 34 | + TorchCompileModel, |
| 35 | + ] |
| 36 | + |
20 | 37 |
|
21 |
| -NODE_CLASS_MAPPINGS = { |
22 |
| - "TorchCompileModel": TorchCompileModel, |
23 |
| -} |
| 38 | +async def comfy_entrypoint() -> TorchCompileExtension: |
| 39 | + return TorchCompileExtension() |
0 commit comments