|
50 | 50 | ) |
51 | 51 | from diffusers.optimization import get_scheduler |
52 | 52 | from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory |
53 | | -from diffusers.utils import check_min_version, is_wandb_available |
| 53 | +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid |
54 | 54 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
55 | 55 | from diffusers.utils.torch_utils import is_compiled_module |
56 | 56 |
|
|
64 | 64 | logger = get_logger(__name__) |
65 | 65 |
|
66 | 66 |
|
67 | | -def image_grid(imgs, rows, cols): |
68 | | - assert len(imgs) == rows * cols |
69 | | - |
70 | | - w, h = imgs[0].size |
71 | | - grid = Image.new("RGB", size=(cols * w, rows * h)) |
72 | | - |
73 | | - for i, img in enumerate(imgs): |
74 | | - grid.paste(img, box=(i % cols * w, i // cols * h)) |
75 | | - return grid |
76 | | - |
77 | | - |
78 | 67 | def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): |
79 | 68 | logger.info("Running validation... ") |
80 | 69 |
|
@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N |
224 | 213 | validation_image.save(os.path.join(repo_folder, "image_control.png")) |
225 | 214 | img_str += f"prompt: {validation_prompt}\n" |
226 | 215 | images = [validation_image] + images |
227 | | - image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) |
| 216 | + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) |
228 | 217 | img_str += f"\n" |
229 | 218 |
|
230 | 219 | model_description = f""" |
|
0 commit comments