Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 42 additions & 28 deletions comfy_extras/nodes_lora_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io

CLAMP_QUANTILE = 0.99

Expand Down Expand Up @@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd

class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class LoraSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraSave",
display_name="Extract and Save Lora",
category="_for_testing",
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
io.Boolean.Input("bias_diff", default=True),
io.Model.Input(
"model_diff",
tooltip="The ModelSubtract output to be converted to a lora.",
optional=True,
),
io.Clip.Input(
"text_encoder_diff",
tooltip="The CLIPSubtract output to be converted to a lora.",
optional=True,
),
],
is_experimental=True,
is_output_node=True,
)

@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True

CATEGORY = "_for_testing"

def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
if model_diff is None and text_encoder_diff is None:
return {}
return io.NodeOutput()

lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())

output_sd = {}
if model_diff is not None:
Expand All @@ -108,12 +118,16 @@ def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, tex
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
return io.NodeOutput()


class LoraSaveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoraSave,
]

NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}

NODE_DISPLAY_NAME_MAPPINGS = {
"LoraSave": "Extract and Save Lora"
}
async def comfy_entrypoint() -> LoraSaveExtension:
return LoraSaveExtension()