Skip to content

Commit eaf68c9

Browse files
Make lora training work on Z Image and remove some redundant nodes. (comfyanonymous#10927)
1 parent cc6a8dc commit eaf68c9

File tree

2 files changed

+3
-103
lines changed

2 files changed

+3
-103
lines changed

comfy/ldm/lumina/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def patchify_and_embed(
509509

510510
if self.pad_tokens_multiple is not None:
511511
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
512-
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
512+
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
513513

514514
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
515515
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
@@ -525,7 +525,7 @@ def patchify_and_embed(
525525

526526
if self.pad_tokens_multiple is not None:
527527
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
528-
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
528+
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
529529
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
530530

531531
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)

comfy_extras/nodes_dataset.py

Lines changed: 1 addition & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import os
3-
import math
43
import json
54

65
import numpy as np
@@ -624,79 +623,6 @@ def _group_process(cls, texts, **kwargs):
624623
# ========== Image Transform Nodes ==========
625624

626625

627-
class ResizeImagesToSameSizeNode(ImageProcessingNode):
628-
node_id = "ResizeImagesToSameSize"
629-
display_name = "Resize Images to Same Size"
630-
description = "Resize all images to the same width and height."
631-
extra_inputs = [
632-
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
633-
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
634-
io.Combo.Input(
635-
"mode",
636-
options=["stretch", "crop_center", "pad"],
637-
default="stretch",
638-
tooltip="Resize mode.",
639-
),
640-
]
641-
642-
@classmethod
643-
def _process(cls, image, width, height, mode):
644-
img = tensor_to_pil(image)
645-
646-
if mode == "stretch":
647-
img = img.resize((width, height), Image.Resampling.LANCZOS)
648-
elif mode == "crop_center":
649-
left = max(0, (img.width - width) // 2)
650-
top = max(0, (img.height - height) // 2)
651-
right = min(img.width, left + width)
652-
bottom = min(img.height, top + height)
653-
img = img.crop((left, top, right, bottom))
654-
if img.width != width or img.height != height:
655-
img = img.resize((width, height), Image.Resampling.LANCZOS)
656-
elif mode == "pad":
657-
img.thumbnail((width, height), Image.Resampling.LANCZOS)
658-
new_img = Image.new("RGB", (width, height), (0, 0, 0))
659-
paste_x = (width - img.width) // 2
660-
paste_y = (height - img.height) // 2
661-
new_img.paste(img, (paste_x, paste_y))
662-
img = new_img
663-
664-
return pil_to_tensor(img)
665-
666-
667-
class ResizeImagesToPixelCountNode(ImageProcessingNode):
668-
node_id = "ResizeImagesToPixelCount"
669-
display_name = "Resize Images to Pixel Count"
670-
description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
671-
extra_inputs = [
672-
io.Int.Input(
673-
"pixel_count",
674-
default=512 * 512,
675-
min=1,
676-
max=8192 * 8192,
677-
tooltip="Target pixel count.",
678-
),
679-
io.Int.Input(
680-
"steps",
681-
default=64,
682-
min=1,
683-
max=128,
684-
tooltip="The stepping for resize width/height.",
685-
),
686-
]
687-
688-
@classmethod
689-
def _process(cls, image, pixel_count, steps):
690-
img = tensor_to_pil(image)
691-
w, h = img.size
692-
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
693-
new_w = int(w * pixel_count_ratio / steps) * steps
694-
new_h = int(h * pixel_count_ratio / steps) * steps
695-
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
696-
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
697-
return pil_to_tensor(img)
698-
699-
700626
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
701627
node_id = "ResizeImagesByShorterEdge"
702628
display_name = "Resize Images by Shorter Edge"
@@ -801,29 +727,6 @@ def _process(cls, image, width, height, seed):
801727
return pil_to_tensor(img)
802728

803729

804-
class FlipImagesNode(ImageProcessingNode):
805-
node_id = "FlipImages"
806-
display_name = "Flip Images"
807-
description = "Flip all images horizontally or vertically."
808-
extra_inputs = [
809-
io.Combo.Input(
810-
"direction",
811-
options=["horizontal", "vertical"],
812-
default="horizontal",
813-
tooltip="Flip direction.",
814-
),
815-
]
816-
817-
@classmethod
818-
def _process(cls, image, direction):
819-
img = tensor_to_pil(image)
820-
if direction == "horizontal":
821-
img = img.transpose(Image.FLIP_LEFT_RIGHT)
822-
else:
823-
img = img.transpose(Image.FLIP_TOP_BOTTOM)
824-
return pil_to_tensor(img)
825-
826-
827730
class NormalizeImagesNode(ImageProcessingNode):
828731
node_id = "NormalizeImages"
829732
display_name = "Normalize Images"
@@ -1470,7 +1373,7 @@ def execute(cls, folder_name):
14701373
shard_path = os.path.join(dataset_dir, shard_file)
14711374

14721375
with open(shard_path, "rb") as f:
1473-
shard_data = torch.load(f)
1376+
shard_data = torch.load(f, weights_only=True)
14741377

14751378
all_latents.extend(shard_data["latents"])
14761379
all_conditioning.extend(shard_data["conditioning"])
@@ -1496,13 +1399,10 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
14961399
SaveImageDataSetToFolderNode,
14971400
SaveImageTextDataSetToFolderNode,
14981401
# Image transform nodes
1499-
ResizeImagesToSameSizeNode,
1500-
ResizeImagesToPixelCountNode,
15011402
ResizeImagesByShorterEdgeNode,
15021403
ResizeImagesByLongerEdgeNode,
15031404
CenterCropImagesNode,
15041405
RandomCropImagesNode,
1505-
FlipImagesNode,
15061406
NormalizeImagesNode,
15071407
AdjustBrightnessNode,
15081408
AdjustContrastNode,

0 commit comments

Comments
 (0)