diff --git a/examples/controlnet/README_sd3.md b/examples/controlnet/README_sd3.md index 1788e07a21d6..7a7b4841125f 100644 --- a/examples/controlnet/README_sd3.md +++ b/examples/controlnet/README_sd3.md @@ -104,7 +104,7 @@ from diffusers.utils import load_image import torch base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers" -controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet" +controlnet_path = "DavyMorgan/sd3-controlnet-out" controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) pipe = StableDiffusion3ControlNetPipeline.from_pretrained( diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 4fae8a072c6f..dbe41578dc09 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -50,7 +50,7 @@ ) from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -64,17 +64,6 @@ logger = get_logger(__name__) -def image_grid(imgs, rows, cols): - assert len(imgs) == rows * cols - - w, h = imgs[0].size - grid = Image.new("RGB", size=(cols * w, rows * h)) - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - return grid - - def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): logger.info("Running validation... ") @@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N validation_image.save(os.path.join(repo_folder, "image_control.png")) img_str += f"prompt: {validation_prompt}\n" images = [validation_image] + images - image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) img_str += f"![images_{i})](./images_{i}.png)\n" model_description = f"""