Skip to content

Commit 7452974

Browse files
committed
Merge branch 'sd3' into po
2 parents 46414bb + 5a18a03 commit 7452974

32 files changed

+1784
-931
lines changed

README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@ The command to install PyTorch is as follows:
1414

1515
### Recent Updates
1616

17+
Apr 6, 2025:
18+
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
19+
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
20+
21+
Mar 30, 2025:
22+
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
23+
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
24+
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).
25+
26+
Mar 20, 2025:
27+
- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).
28+
- For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`.
29+
30+
Mar 6, 2025:
31+
32+
- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)
33+
34+
Feb 26, 2025:
35+
36+
- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
37+
- The validation loss uses the fixed timestep sampling and the fixed random seed. This is to ensure that the validation loss is not fluctuated by the random values.
38+
1739
Jan 25, 2025:
1840

1941
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
@@ -739,6 +761,8 @@ Not available yet.
739761
[__Change History__](#change-history) is moved to the bottom of the page.
740762
更新履歴は[ページ末尾](#change-history)に移しました。
741763

764+
Latest update: 2025-03-21 (Version 0.9.1)
765+
742766
[日本語版READMEはこちら](./README-ja.md)
743767

744768
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
@@ -882,6 +906,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
882906

883907
## Change History
884908

909+
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
910+
911+
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
912+
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
913+
885914
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
886915

887916
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.

docs/config_README-en.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ These options are related to subset configuration.
152152
| `keep_tokens_separator` | `“|||”` | o | o | o |
153153
| `secondary_separator` | `“;;;”` | o | o | o |
154154
| `enable_wildcard` | `true` | o | o | o |
155+
| `resize_interpolation` | (not specified) | o | o | o |
155156

156157
* `num_repeats`
157158
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
@@ -165,6 +166,8 @@ These options are related to subset configuration.
165166
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
166167
* `enable_wildcard`
167168
* Enables wildcard notation. This will be explained later.
169+
* `resize_interpolation`
170+
* Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used.
168171

169172
### DreamBooth-specific options
170173

docs/config_README-ja.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
144144
| `keep_tokens_separator` | `“|||”` | o | o | o |
145145
| `secondary_separator` | `“;;;”` | o | o | o |
146146
| `enable_wildcard` | `true` | o | o | o |
147+
| `resize_interpolation` |(通常は設定しません) | o | o | o |
147148

148149
* `num_repeats`
149150
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
@@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
162163
* `enable_wildcard`
163164
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
164165

166+
* `resize_interpolation`
167+
* 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos``box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。
168+
165169
### DreamBooth 方式専用のオプション
166170

167171
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。

finetune/tag_images_by_wd14_tagger.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tqdm import tqdm
1212

1313
import library.train_util as train_util
14-
from library.utils import setup_logging, pil_resize
14+
from library.utils import setup_logging, resize_image
1515

1616
setup_logging()
1717
import logging
@@ -42,10 +42,7 @@ def preprocess_image(image):
4242
pad_t = pad_y // 2
4343
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
4444

45-
if size > IMAGE_SIZE:
46-
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
47-
else:
48-
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
45+
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
4946

5047
image = image.astype(np.float32)
5148
return image

flux_train_network.py

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ def __init__(self):
3636
self.is_schnell: Optional[bool] = None
3737
self.is_swapping_blocks: bool = False
3838

39-
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
39+
def assert_extra_args(
40+
self,
41+
args,
42+
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
43+
val_dataset_group: Optional[train_util.DatasetGroup],
44+
):
4045
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
4146
# sdxl_train_util.verify_sdxl_training_args(args)
4247

@@ -323,7 +328,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) ->
323328
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
324329
return noise_scheduler
325330

326-
def encode_images_to_latents(self, args, accelerator, vae, images):
331+
def encode_images_to_latents(self, args, vae, images):
327332
return vae.encode(images)
328333

329334
def shift_scale_latents(self, args, latents):
@@ -341,7 +346,7 @@ def get_noise_pred_and_target(
341346
network,
342347
weight_dtype,
343348
train_unet,
344-
is_train=True
349+
is_train=True,
345350
):
346351
# Sample noise that we'll add to the latents
347352
noise = torch.randn_like(latents)
@@ -376,8 +381,7 @@ def get_noise_pred_and_target(
376381
t5_attn_mask = None
377382

378383
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
379-
# if not args.split_mode:
380-
# normal forward
384+
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
381385
with torch.set_grad_enabled(is_train), accelerator.autocast():
382386
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
383387
model_pred = unet(
@@ -390,44 +394,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
390394
guidance=guidance_vec,
391395
txt_attention_mask=t5_attn_mask,
392396
)
393-
"""
394-
else:
395-
# split forward to reduce memory usage
396-
assert network.train_blocks == "single", "train_blocks must be single for split mode"
397-
with accelerator.autocast():
398-
# move flux lower to cpu, and then move flux upper to gpu
399-
unet.to("cpu")
400-
clean_memory_on_device(accelerator.device)
401-
self.flux_upper.to(accelerator.device)
402-
403-
# upper model does not require grad
404-
with torch.no_grad():
405-
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
406-
img=packed_noisy_model_input,
407-
img_ids=img_ids,
408-
txt=t5_out,
409-
txt_ids=txt_ids,
410-
y=l_pooled,
411-
timesteps=timesteps / 1000,
412-
guidance=guidance_vec,
413-
txt_attention_mask=t5_attn_mask,
414-
)
415-
416-
# move flux upper back to cpu, and then move flux lower to gpu
417-
self.flux_upper.to("cpu")
418-
clean_memory_on_device(accelerator.device)
419-
unet.to(accelerator.device)
420-
421-
# lower model requires grad
422-
intermediate_img.requires_grad_(True)
423-
intermediate_txt.requires_grad_(True)
424-
vec.requires_grad_(True)
425-
pe.requires_grad_(True)
426-
427-
with torch.set_grad_enabled(is_train and train_unet):
428-
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
429-
"""
430-
431397
return model_pred
432398

433399
model_pred = call_dit(
@@ -546,6 +512,11 @@ def forward(hidden_states):
546512
text_encoder.to(te_weight_dtype) # fp8
547513
prepare_fp8(text_encoder, weight_dtype)
548514

515+
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
516+
if self.is_swapping_blocks:
517+
# prepare for next forward: because backward pass is not called, we need to prepare it here
518+
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
519+
549520
def prepare_unet_with_accelerator(
550521
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
551522
) -> torch.nn.Module:

library/config_util.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class BaseSubsetParams:
7575
custom_attributes: Optional[Dict[str, Any]] = None
7676
validation_seed: int = 0
7777
validation_split: float = 0.0
78+
resize_interpolation: Optional[str] = None
7879
preference: bool = False
7980
preference_caption_prefix: Optional[str] = None
8081
preference_caption_suffix: Optional[str] = None
@@ -111,7 +112,7 @@ class BaseDatasetParams:
111112
debug_dataset: bool = False
112113
validation_seed: Optional[int] = None
113114
validation_split: float = 0.0
114-
115+
resize_interpolation: Optional[str] = None
115116

116117
@dataclass
117118
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -201,6 +202,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
201202
"caption_prefix": str,
202203
"caption_suffix": str,
203204
"custom_attributes": dict,
205+
"resize_interpolation": str,
204206
"preference": bool,
205207
"preference_caption_prefix": str,
206208
"preference_caption_suffix": str,
@@ -251,6 +253,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
251253
"validation_split": float,
252254
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
253255
"network_multiplier": float,
256+
"resize_interpolation": str,
254257
}
255258

256259
# options handled by argparse but not handled by user config
@@ -535,6 +538,7 @@ def print_info(_datasets, dataset_type: str):
535538
[{dataset_type} {i}]
536539
batch_size: {dataset.batch_size}
537540
resolution: {(dataset.width, dataset.height)}
541+
resize_interpolation: {dataset.resize_interpolation}
538542
enable_bucket: {dataset.enable_bucket}
539543
""")
540544

@@ -568,6 +572,7 @@ def print_info(_datasets, dataset_type: str):
568572
token_warmup_min: {subset.token_warmup_min},
569573
token_warmup_step: {subset.token_warmup_step},
570574
alpha_mask: {subset.alpha_mask}
575+
resize_interpolation: {subset.resize_interpolation}
571576
custom_attributes: {subset.custom_attributes}
572577
"""), " ")
573578

library/device_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
import gc
33

44
import torch
5+
try:
6+
# intel gpu support for pytorch older than 2.5
7+
# ipex is not needed after pytorch 2.5
8+
import intel_extension_for_pytorch as ipex # noqa
9+
except Exception:
10+
pass
11+
512

613
try:
714
HAS_CUDA = torch.cuda.is_available()
@@ -14,8 +21,6 @@
1421
HAS_MPS = False
1522

1623
try:
17-
import intel_extension_for_pytorch as ipex # noqa
18-
1924
HAS_XPU = torch.xpu.is_available()
2025
except Exception:
2126
HAS_XPU = False
@@ -69,7 +74,7 @@ def init_ipex():
6974
7075
This function should run right after importing torch and before doing anything else.
7176
72-
If IPEX is not available, this function does nothing.
77+
If xpu is not available, this function does nothing.
7378
"""
7479
try:
7580
if HAS_XPU:

library/flux_train_utils.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
366366
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
367367

368368
sigma = sigmas[step_indices].flatten()
369-
while len(sigma.shape) < n_dim:
370-
sigma = sigma.unsqueeze(-1)
371369
return sigma
372370

373371

@@ -410,42 +408,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
410408

411409

412410
def get_noisy_model_input_and_timesteps(
413-
args, noise_scheduler, latents, noise, device, dtype
411+
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
414412
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
415413
bsz, _, h, w = latents.shape
416-
sigmas = None
417-
414+
assert bsz > 0, "Batch size not large enough"
415+
num_timesteps = noise_scheduler.config.num_train_timesteps
418416
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
419-
# Simple random t-based noise sampling
417+
# Simple random sigma-based noise sampling
420418
if args.timestep_sampling == "sigmoid":
421419
# https://github.com/XLabs-AI/x-flux/tree/main
422-
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
420+
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
423421
else:
424-
t = torch.rand((bsz,), device=device)
422+
sigmas = torch.rand((bsz,), device=device)
425423

426-
timesteps = t * 1000.0
427-
t = t.view(-1, 1, 1, 1)
428-
noisy_model_input = (1 - t) * latents + t * noise
424+
timesteps = sigmas * num_timesteps
429425
elif args.timestep_sampling == "shift":
430426
shift = args.discrete_flow_shift
431-
logits_norm = torch.randn(bsz, device=device)
432-
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
433-
timesteps = logits_norm.sigmoid()
434-
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
435-
436-
t = timesteps.view(-1, 1, 1, 1)
437-
timesteps = timesteps * 1000.0
438-
noisy_model_input = (1 - t) * latents + t * noise
427+
sigmas = torch.randn(bsz, device=device)
428+
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
429+
sigmas = sigmas.sigmoid()
430+
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
431+
timesteps = sigmas * num_timesteps
439432
elif args.timestep_sampling == "flux_shift":
440-
logits_norm = torch.randn(bsz, device=device)
441-
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
442-
timesteps = logits_norm.sigmoid()
443-
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
444-
timesteps = time_shift(mu, 1.0, timesteps)
445-
446-
t = timesteps.view(-1, 1, 1, 1)
447-
timesteps = timesteps * 1000.0
448-
noisy_model_input = (1 - t) * latents + t * noise
433+
sigmas = torch.randn(bsz, device=device)
434+
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
435+
sigmas = sigmas.sigmoid()
436+
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
437+
sigmas = time_shift(mu, 1.0, sigmas)
438+
timesteps = sigmas * num_timesteps
449439
else:
450440
# Sample a random timestep for each image
451441
# for weighting schemes where we sample timesteps non-uniformly
@@ -456,12 +446,24 @@ def get_noisy_model_input_and_timesteps(
456446
logit_std=args.logit_std,
457447
mode_scale=args.mode_scale,
458448
)
459-
indices = (u * noise_scheduler.config.num_train_timesteps).long()
449+
indices = (u * num_timesteps).long()
460450
timesteps = noise_scheduler.timesteps[indices].to(device=device)
461-
462-
# Add noise according to flow matching.
463451
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
464-
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
452+
453+
# Broadcast sigmas to latent shape
454+
sigmas = sigmas.view(-1, 1, 1, 1)
455+
456+
# Add noise to the latents according to the noise magnitude at each timestep
457+
# (this is the forward diffusion process)
458+
if args.ip_noise_gamma:
459+
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
460+
if args.ip_noise_gamma_random_strength:
461+
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
462+
else:
463+
ip_noise_gamma = args.ip_noise_gamma
464+
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
465+
else:
466+
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
465467

466468
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
467469

0 commit comments

Comments
 (0)