Skip to content

Commit 4965189

Browse files
authored
Merge pull request #1903 from kohya-ss/val-loss-improvement
Val loss improvement
2 parents ae409e8 + 1fcac98 commit 4965189

File tree

4 files changed

+305
-241
lines changed

4 files changed

+305
-241
lines changed

flux_train_network.py

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,7 @@ def get_noise_pred_and_target(
381381
t5_attn_mask = None
382382

383383
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
384-
# if not args.split_mode:
385-
# normal forward
384+
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
386385
with torch.set_grad_enabled(is_train), accelerator.autocast():
387386
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
388387
model_pred = unet(
@@ -395,44 +394,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
395394
guidance=guidance_vec,
396395
txt_attention_mask=t5_attn_mask,
397396
)
398-
"""
399-
else:
400-
# split forward to reduce memory usage
401-
assert network.train_blocks == "single", "train_blocks must be single for split mode"
402-
with accelerator.autocast():
403-
# move flux lower to cpu, and then move flux upper to gpu
404-
unet.to("cpu")
405-
clean_memory_on_device(accelerator.device)
406-
self.flux_upper.to(accelerator.device)
407-
408-
# upper model does not require grad
409-
with torch.no_grad():
410-
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
411-
img=packed_noisy_model_input,
412-
img_ids=img_ids,
413-
txt=t5_out,
414-
txt_ids=txt_ids,
415-
y=l_pooled,
416-
timesteps=timesteps / 1000,
417-
guidance=guidance_vec,
418-
txt_attention_mask=t5_attn_mask,
419-
)
420-
421-
# move flux upper back to cpu, and then move flux lower to gpu
422-
self.flux_upper.to("cpu")
423-
clean_memory_on_device(accelerator.device)
424-
unet.to(accelerator.device)
425-
426-
# lower model requires grad
427-
intermediate_img.requires_grad_(True)
428-
intermediate_txt.requires_grad_(True)
429-
vec.requires_grad_(True)
430-
pe.requires_grad_(True)
431-
432-
with torch.set_grad_enabled(is_train and train_unet):
433-
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
434-
"""
435-
436397
return model_pred
437398

438399
model_pred = call_dit(
@@ -551,6 +512,11 @@ def forward(hidden_states):
551512
text_encoder.to(te_weight_dtype) # fp8
552513
prepare_fp8(text_encoder, weight_dtype)
553514

515+
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
516+
if self.is_swapping_blocks:
517+
# prepare for next forward: because backward pass is not called, we need to prepare it here
518+
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
519+
554520
def prepare_unet_with_accelerator(
555521
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
556522
) -> torch.nn.Module:

library/train_util.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,7 @@
1313
import shutil
1414
import time
1515
import typing
16-
from typing import (
17-
Any,
18-
Callable,
19-
Dict,
20-
List,
21-
NamedTuple,
22-
Optional,
23-
Sequence,
24-
Tuple,
25-
Union
26-
)
16+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
2717
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
2818
import glob
2919
import math
@@ -146,12 +136,13 @@
146136
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
147137
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
148138

139+
149140
def split_train_val(
150-
paths: List[str],
141+
paths: List[str],
151142
sizes: List[Optional[Tuple[int, int]]],
152-
is_training_dataset: bool,
153-
validation_split: float,
154-
validation_seed: int | None
143+
is_training_dataset: bool,
144+
validation_split: float,
145+
validation_seed: int | None,
155146
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
156147
"""
157148
Split the dataset into train and validation
@@ -1842,7 +1833,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
18421833
class DreamBoothDataset(BaseDataset):
18431834
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
18441835

1845-
# The is_training_dataset defines the type of dataset, training or validation
1836+
# The is_training_dataset defines the type of dataset, training or validation
18461837
# if is_training_dataset is True -> training dataset
18471838
# if is_training_dataset is False -> validation dataset
18481839
def __init__(
@@ -1981,29 +1972,25 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
19811972
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
19821973

19831974
# We want to create a training and validation split. This should be improved in the future
1984-
# to allow a clearer distinction between training and validation. This can be seen as a
1975+
# to allow a clearer distinction between training and validation. This can be seen as a
19851976
# short-term solution to limit what is necessary to implement validation datasets
1986-
#
1977+
#
19871978
# We split the dataset for the subset based on if we are doing a validation split
1988-
# The self.is_training_dataset defines the type of dataset, training or validation
1979+
# The self.is_training_dataset defines the type of dataset, training or validation
19891980
# if self.is_training_dataset is True -> training dataset
19901981
# if self.is_training_dataset is False -> validation dataset
19911982
if self.validation_split > 0.0:
1992-
# For regularization images we do not want to split this dataset.
1983+
# For regularization images we do not want to split this dataset.
19931984
if subset.is_reg is True:
19941985
# Skip any validation dataset for regularization images
19951986
if self.is_training_dataset is False:
19961987
img_paths = []
19971988
sizes = []
1998-
# Otherwise the img_paths remain as original img_paths and no split
1989+
# Otherwise the img_paths remain as original img_paths and no split
19991990
# required for training images dataset of regularization images
20001991
else:
20011992
img_paths, sizes = split_train_val(
2002-
img_paths,
2003-
sizes,
2004-
self.is_training_dataset,
2005-
self.validation_split,
2006-
self.validation_seed
1993+
img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
20071994
)
20081995

20091996
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
@@ -2373,7 +2360,7 @@ def __init__(
23732360
bucket_no_upscale: bool,
23742361
debug_dataset: bool,
23752362
validation_split: float,
2376-
validation_seed: Optional[int],
2363+
validation_seed: Optional[int],
23772364
) -> None:
23782365
super().__init__(resolution, network_multiplier, debug_dataset)
23792366

@@ -2431,9 +2418,9 @@ def __init__(
24312418
self.image_data = self.dreambooth_dataset_delegate.image_data
24322419
self.batch_size = batch_size
24332420
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
2434-
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
2421+
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
24352422
self.validation_split = validation_split
2436-
self.validation_seed = validation_seed
2423+
self.validation_seed = validation_seed
24372424

24382425
# assert all conditioning data exists
24392426
missing_imgs = []
@@ -5944,12 +5931,17 @@ def save_sd_model_on_train_end_common(
59445931

59455932

59465933
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
5947-
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
5934+
if min_timestep < max_timestep:
5935+
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
5936+
else:
5937+
timesteps = torch.full((b_size,), max_timestep, device="cpu")
59485938
timesteps = timesteps.long().to(device)
59495939
return timesteps
59505940

59515941

5952-
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
5942+
def get_noise_noisy_latents_and_timesteps(
5943+
args, noise_scheduler, latents: torch.FloatTensor
5944+
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
59535945
# Sample noise that we'll add to the latents
59545946
noise = torch.randn_like(latents, device=latents.device)
59555947
if args.noise_offset:
@@ -6441,7 +6433,7 @@ def sample_image_inference(
64416433
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
64426434

64436435

6444-
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
6436+
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
64456437
"""
64466438
Initialize experiment trackers with tracker specific behaviors
64476439
"""
@@ -6458,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
64586450
)
64596451

64606452
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
6461-
import wandb
6453+
import wandb
6454+
64626455
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
64636456

64646457
# Define specific metrics to handle validation and epochs "steps"
64656458
wandb_tracker.define_metric("epoch", hidden=True)
64666459
wandb_tracker.define_metric("val_step", hidden=True)
64676460

6461+
wandb_tracker.define_metric("global_step", hidden=True)
6462+
6463+
64686464
# endregion
64696465

64706466

sd3_train_network.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,14 +450,19 @@ def forward(hidden_states):
450450
text_encoder.to(te_weight_dtype) # fp8
451451
prepare_fp8(text_encoder, weight_dtype)
452452

453-
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
454-
# drop cached text encoder outputs
453+
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
454+
# drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
455455
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
456456
if text_encoder_outputs_list is not None:
457457
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
458458
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
459459
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
460460

461+
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
462+
if self.is_swapping_blocks:
463+
# prepare for next forward: because backward pass is not called, we need to prepare it here
464+
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
465+
461466
def prepare_unet_with_accelerator(
462467
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
463468
) -> torch.nn.Module:

0 commit comments

Comments (0)