
Commit 48d368f

Merge pull request #2268 from kohya-ss/sd3
merge sd3 to main
2 parents ae72efb + 3265f2e commit 48d368f

6 files changed: +151 -123 lines changed

library/deepspeed_utils.py

Lines changed: 13 additions & 14 deletions
@@ -96,7 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
    deepspeed_plugin.deepspeed_config["train_batch_size"] = (
        args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
    )
-
+
    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
    if args.mixed_precision.lower() == "fp16":
        deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
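For context (not part of the diff): DeepSpeed requires train_batch_size to equal per-device micro-batch size x gradient-accumulation steps x number of processes, which is exactly what the lines above compute from WORLD_SIZE. A quick sketch with assumed values:

# Sketch with assumed values mirroring the computation above:
# micro-batch 4 per device, gradient accumulation 8, WORLD_SIZE=2
train_batch_size = 4 * 8 * 2
print(train_batch_size)  # 64 samples contribute to every optimizer step
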
@@ -125,18 +125,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
    class DeepSpeedWrapper(torch.nn.Module):
        def __init__(self, **kw_models) -> None:
            super().__init__()
-
+
            self.models = torch.nn.ModuleDict()
-
-            wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
+
+            wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"

            for key, model in kw_models.items():
                if isinstance(model, list):
                    model = torch.nn.ModuleList(model)
-
+
                if wrap_model_forward_with_torch_autocast:
-                    model = self.__wrap_model_with_torch_autocast(model)
-
+                    model = self.__wrap_model_with_torch_autocast(model)
+
                assert isinstance(
                    model, torch.nn.Module
                ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
@@ -151,7 +151,7 @@ def __wrap_model_with_torch_autocast(self, model):
            return model

        def __wrap_model_forward_with_torch_autocast(self, model):
-
+
            assert hasattr(model, "forward"), f"model must have a forward method."

            forward_fn = model.forward
@@ -161,20 +161,19 @@ def forward(*args, **kwargs):
                    device_type = model.device.type
                except AttributeError:
                    logger.warning(
-                       "[DeepSpeed] model.device is not available. Using get_preferred_device() "
-                       "to determine the device_type for torch.autocast()."
-                   )
+                        "[DeepSpeed] model.device is not available. Using get_preferred_device() "
+                        "to determine the device_type for torch.autocast()."
+                    )
                    device_type = get_preferred_device().type

-                with torch.autocast(device_type = device_type):
+                with torch.autocast(device_type=device_type):
                    return forward_fn(*args, **kwargs)

            model.forward = forward
            return model
-
+
        def get_models(self):
            return self.models
-

    ds_model = DeepSpeedWrapper(**models)
    return ds_model
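The wrapper above monkey-patches each sub-model's forward so that, under DeepSpeed, mixed-precision autocast is still applied around every call. A self-contained sketch of the same idea (simplified; the function name and CPU example are illustrative, not the repository's code):

import torch

def wrap_forward_with_autocast(model: torch.nn.Module, device_type: str) -> torch.nn.Module:
    # Replace model.forward with a version that always runs inside torch.autocast.
    forward_fn = model.forward

    def forward(*args, **kwargs):
        with torch.autocast(device_type=device_type):
            return forward_fn(*args, **kwargs)

    model.forward = forward
    return model

# usage: activations are computed in reduced precision, parameters stay in fp32
linear = wrap_forward_with_autocast(torch.nn.Linear(8, 8), device_type="cpu")
print(linear(torch.randn(2, 8)).dtype)  # torch.bfloat16 under CPU autocast
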

library/lumina_models.py

Lines changed: 46 additions & 51 deletions
@@ -34,18 +34,18 @@
try:
    from flash_attn import flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-except:
+except ImportError:
    # flash_attn may not be available but it is not required
    pass

try:
    from sageattention import sageattn
-except:
+except ImportError:
    pass

try:
    from apex.normalization import FusedRMSNorm as RMSNorm
-except:
+except ImportError:
    import warnings

    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
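Context for the `except ImportError:` changes: a bare `except:` also swallows KeyboardInterrupt, SystemExit, and unrelated errors raised during import (for example a broken installation), which can hide real problems; catching ImportError keeps the optional-dependency behavior while letting everything else propagate. A minimal sketch of the pattern (not this module's exact code, which relies on a later NameError check rather than a sentinel):

# Optional-dependency import (sketch):
try:
    from sageattention import sageattn  # optional accelerated kernel, may not be installed
except ImportError:
    sageattn = None  # feature disabled; Ctrl-C and genuine import-time errors still propagate

def sage_attention_available() -> bool:
    return sageattn is not None
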
@@ -98,7 +98,7 @@ def forward(self, x: Tensor):
        x_dtype = x.dtype
        # To handle float8 we need to convert the tensor to float
        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)

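For reference, a minimal RMSNorm in the same style, with the epsilon taken from the module instead of a hard-coded 1e-6 (a sketch with assumed attribute names, not the repository's class):

import torch
from torch import Tensor

class SimpleRMSNorm(torch.nn.Module):
    # Normalize by the root-mean-square over the last dimension, then apply a learned scale.
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        x_dtype = x.dtype
        x = x.float()  # compute statistics in fp32 (handles fp16/bf16/float8 inputs)
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)

print(SimpleRMSNorm(8)(torch.randn(2, 8, dtype=torch.bfloat16)).dtype)  # torch.bfloat16
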
@@ -370,7 +370,7 @@ def forward(
        if self.use_sage_attn:
            # Handle GQA (Grouped Query Attention) if needed
            n_rep = self.n_local_heads // self.n_local_kv_heads
-            if n_rep >= 1:
+            if n_rep > 1:
                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

@@ -379,7 +379,7 @@ def forward(
            output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
        else:
            n_rep = self.n_local_heads // self.n_local_kv_heads
-            if n_rep >= 1:
+            if n_rep > 1:
                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

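The `n_rep > 1` guard skips the expansion when query and key/value head counts already match (n_rep == 1), where the unsqueeze/repeat/flatten would only produce a needless copy. Shape-wise the GQA expansion looks like this (sketch with assumed sizes):

import torch

bsz, seqlen, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
n_rep = n_heads // n_kv_heads  # 4: each key/value head serves 4 query heads

xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
if n_rep > 1:  # with n_rep == 1 this branch would just copy the tensor
    xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
print(xk.shape)  # torch.Size([2, 16, 8, 64]) -- kv heads broadcast up to the query head count
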
@@ -456,51 +456,47 @@ def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_sca
            bsz = q.shape[0]
            seqlen = q.shape[1]

-            # Transpose tensors to match SageAttention's expected format (HND layout)
-            q_transposed = q.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-            k_transposed = k.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-            v_transposed = v.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-
-            # Handle masking for SageAttention
-            # We need to filter out masked positions - this approach handles variable sequence lengths
-            outputs = []
-            for b in range(bsz):
-                # Find valid token positions from the mask
-                valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
-                if valid_indices.numel() == 0:
-                    # If all tokens are masked, create a zero output
-                    batch_output = torch.zeros(
-                        seqlen, self.n_local_heads, self.head_dim,
-                        device=q.device, dtype=q.dtype
-                    )
-                else:
-                    # Extract only valid tokens for this batch
-                    batch_q = q_transposed[b, :, valid_indices, :]
-                    batch_k = k_transposed[b, :, valid_indices, :]
-                    batch_v = v_transposed[b, :, valid_indices, :]
-
-                    # Run SageAttention on valid tokens only
+            # Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim]
+            q_transposed = q.permute(0, 2, 1, 3)
+            k_transposed = k.permute(0, 2, 1, 3)
+            v_transposed = v.permute(0, 2, 1, 3)
+
+            # Fast path: if all tokens are valid, run batched SageAttention directly
+            if x_mask.all():
+                output = sageattn(
+                    q_transposed, k_transposed, v_transposed,
+                    tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
+                )
+                # output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
+                output = output.permute(0, 2, 1, 3)
+            else:
+                # Slow path: per-batch loop to handle variable-length masking
+                # SageAttention does not support attention masks natively
+                outputs = []
+                for b in range(bsz):
+                    valid_indices = x_mask[b].nonzero(as_tuple=True)[0]
+                    if valid_indices.numel() == 0:
+                        outputs.append(torch.zeros(
+                            seqlen, self.n_local_heads, self.head_dim,
+                            device=q.device, dtype=q.dtype,
+                        ))
+                        continue
+
                    batch_output_valid = sageattn(
-                        batch_q.unsqueeze(0),  # Add batch dimension back
-                        batch_k.unsqueeze(0),
-                        batch_v.unsqueeze(0),
-                        tensor_layout="HND",
-                        is_causal=False,
-                        sm_scale=softmax_scale
+                        q_transposed[b:b+1, :, valid_indices, :],
+                        k_transposed[b:b+1, :, valid_indices, :],
+                        v_transposed[b:b+1, :, valid_indices, :],
+                        tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
                    )
-
-                    # Create output tensor with zeros for masked positions
+
                    batch_output = torch.zeros(
-                        seqlen, self.n_local_heads, self.head_dim,
-                        device=q.device, dtype=q.dtype
+                        seqlen, self.n_local_heads, self.head_dim,
+                        device=q.device, dtype=q.dtype,
                    )
-                    # Place valid outputs back in the right positions
                    batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
-
-                outputs.append(batch_output)
-
-            # Stack batch outputs and reshape to expected format
-            output = torch.stack(outputs, dim=0)  # [batch, seq_len, heads, head_dim]
+                    outputs.append(batch_output)
+
+                output = torch.stack(outputs, dim=0)
        except NameError as e:
            raise RuntimeError(
                f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
@@ -1113,10 +1109,9 @@ def patchify_and_embed(

        x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)

-        x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
-        for i in range(bsz):
-            x[i, :image_seq_len] = x[i]
-            x_mask[i, :image_seq_len] = True
+        # x.shape[1] == image_seq_len after patchify, so this was assigning to itself.
+        # The mask can be set without a loop since all samples have the same image_seq_len.
+        x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)

        x = self.x_embedder(x)

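The removed loop copied each sample onto itself: after the view/permute/flatten above, the token axis already has length image_seq_len for every sample (all samples share the same image_seq_len here), so a mask of ones is equivalent. A quick shape check with assumed sizes:

import torch

bsz, channels, height, width, pH, pW = 2, 16, 64, 64, 2, 2
x = torch.randn(bsz, channels, height, width)
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
image_seq_len = (height // pH) * (width // pW)
print(x.shape)  # torch.Size([2, 1024, 64]) -> [bsz, image_seq_len, pH * pW * channels]
assert x.shape[1] == image_seq_len
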
@@ -1389,4 +1384,4 @@ def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
        axes_dims=[40, 40, 40],
        axes_lens=[300, 512, 512],
        **kwargs,
-    )
+    )

library/lumina_train_util.py

Lines changed: 23 additions & 16 deletions
@@ -334,32 +334,35 @@ def sample_image_inference(

    # No need to add system prompt here, as it has been handled in the tokenize_strategy

-    # Get sample prompts from cache
+    # Get sample prompts from cache, fallback to live encoding
+    gemma2_conds = None
+    neg_gemma2_conds = None
+
    if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
        gemma2_conds = sample_prompts_gemma2_outputs[prompt]
        logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")

-    if (
-        sample_prompts_gemma2_outputs
-        and negative_prompt in sample_prompts_gemma2_outputs
-    ):
+    if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
        neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
-        logger.info(
-            f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
-        )
+        logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")

-    # Load sample prompts from Gemma 2
-    if gemma2_model is not None:
+    # Only encode if not found in cache
+    if gemma2_conds is None and gemma2_model is not None:
        tokens_and_masks = tokenize_strategy.tokenize(prompt)
        gemma2_conds = encoding_strategy.encode_tokens(
            tokenize_strategy, gemma2_model, tokens_and_masks
        )

+    if neg_gemma2_conds is None and gemma2_model is not None:
        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
        neg_gemma2_conds = encoding_strategy.encode_tokens(
            tokenize_strategy, gemma2_model, tokens_and_masks
        )

+    if gemma2_conds is None or neg_gemma2_conds is None:
+        logger.error(f"Cannot generate sample: no cached outputs and no text encoder available for prompt: {prompt}")
+        continue
+
    # Unpack Gemma2 outputs
    gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
    neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
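The control flow above is a plain cache-then-fallback: initialize both conditions to None, fill them from the precomputed outputs when present, encode on the fly only for the misses, and skip the sample when neither source is available. Reduced to a skeleton (hypothetical names, not the repository's API):

def get_condition(prompt, cache, encode_fn):
    # Return the cached text-encoder output if present, else encode, else None (sketch).
    if cache and prompt in cache:
        return cache[prompt]
    if encode_fn is not None:
        return encode_fn(prompt)
    return None

cache = {"a cat": "cached-embedding"}
print(get_condition("a cat", cache, None))       # "cached-embedding" (cache hit)
print(get_condition("a dog", cache, str.upper))  # "A DOG" (cache miss, live encoding)
print(get_condition("a dog", {}, None))          # None -> the caller logs an error and skips
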
@@ -475,6 +478,7 @@ def sample_image_inference(


def time_shift(mu: float, sigma: float, t: torch.Tensor):
+    """Apply time shifting to timesteps."""
    t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
    return t

@@ -483,7 +487,7 @@ def get_lin_function(
    x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
) -> Callable[[float], float]:
    """
-    Get linear function
+    Get linear function for resolution-dependent shifting.

    Args:
        image_seq_len,
@@ -528,6 +532,7 @@ def get_schedule(
    mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
        image_seq_len
    )
+    timesteps = torch.clamp(timesteps, min=1e-7).to(timesteps.device)
    timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
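Both changes in this area serve the shifted flow-matching schedule: get_lin_function is the straight line through (256, 0.5) and (4096, 1.15), so longer image token sequences get a larger shift mu, and time_shift(mu, 1, t) = e^mu / (e^mu + 1/t - 1) has a 1/t term that overflows at t = 0, which is what the new clamp to 1e-7 (here and in the training-side sampling further down) guards against. A self-contained numeric sketch with assumed latent sizes:

import math
import torch

def get_lin_function(x1=256.0, x2=4096.0, y1=0.5, y2=1.15):
    slope = (y2 - y1) / (x2 - x1)
    return lambda x: slope * (x - x1) + y1

def time_shift(mu: float, sigma: float, t: torch.Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

h = w = 128                                   # assumed latent height/width
mu = get_lin_function()((h // 2) * (w // 2))  # 4096 packed tokens -> mu = 1.15 (the y2 endpoint)
t = torch.clamp(torch.linspace(0, 1, 5), min=1e-7)  # without the clamp, t = 0 makes 1/t overflow
print(round(mu, 3), time_shift(mu, 1.0, t))   # shifted values are pulled toward 1.0 when mu > 0
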
@@ -689,15 +694,15 @@ def denoise(

        img_dtype = img.dtype

-        if img.dtype != img_dtype:
-            if torch.backends.mps.is_available():
-                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                img = img.to(img_dtype)
-
        # compute the previous noisy sample x_t -> x_t-1
        noise_pred = -noise_pred
        img = scheduler.step(noise_pred, t, img, return_dict=False)[0]

+        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+        if img.dtype != img_dtype:
+            if torch.backends.mps.is_available():
+                img = img.to(img_dtype)
+
    model.prepare_block_swap_before_forward()
    return img

@@ -823,6 +828,7 @@ def get_noisy_model_input_and_timesteps(
        timesteps = sigmas * num_timesteps
    elif args.timestep_sampling == "nextdit_shift":
        sigmas = torch.rand((bsz,), device=device)
+        sigmas = torch.clamp(sigmas, min=1e-7).to(device)
        mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
        sigmas = time_shift(mu, 1.0, sigmas)

@@ -831,6 +837,7 @@
        sigmas = torch.randn(bsz, device=device)
        sigmas = sigmas * args.sigmoid_scale  # larger scale for more uniform sampling
        sigmas = sigmas.sigmoid()
+        sigmas = torch.clamp(sigmas, min=1e-7).to(device)
        mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))  # we are pre-packed so must adjust for packed size
        sigmas = time_shift(mu, 1.0, sigmas)
        timesteps = sigmas * num_timesteps

lumina_train.py

Lines changed: 15 additions & 9 deletions
@@ -370,19 +370,25 @@ def train(args):
        grouped_params = []
        param_group = {}
        for group in params_to_optimize:
-            named_parameters = list(nextdit.named_parameters())
+            named_parameters = [(n, p) for n, p in nextdit.named_parameters() if p.requires_grad]
            assert len(named_parameters) == len(
                group["params"]
-            ), "number of parameters does not match"
+            ), f"number of trainable parameters ({len(named_parameters)}) does not match optimizer group ({len(group['params'])})"
            for p, np in zip(group["params"], named_parameters):
                # determine target layer and block index for each parameter
-                block_type = "other"  # double, single or other
-                if np[0].startswith("double_blocks"):
+                # Lumina NextDiT architecture:
+                # - "layers.{i}.*" : main transformer blocks (e.g. 32 blocks for 2B)
+                # - "context_refiner.{i}.*" : context refiner blocks (2 blocks)
+                # - "noise_refiner.{i}.*" : noise refiner blocks (2 blocks)
+                # - others: t_embedder, cap_embedder, x_embedder, norm_final, final_layer
+                block_type = "other"
+                if np[0].startswith("layers."):
                    block_index = int(np[0].split(".")[1])
-                    block_type = "double"
-                elif np[0].startswith("single_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "single"
+                    block_type = "main"
+                elif np[0].startswith("context_refiner.") or np[0].startswith("noise_refiner."):
+                    # All refiner blocks (context + noise) grouped together
+                    block_index = -1
+                    block_type = "refiner"
                else:
                    block_index = -1

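The removed branches looked for double_blocks / single_blocks prefixes that do not exist in NextDiT; the new ones follow the naming listed in the added comments, with all refiner blocks collapsed into one group. The classification boils down to a small name parser (sketch; the example parameter names are illustrative):

def classify_param(name: str):
    # Map a NextDiT parameter name to (block_type, block_index) -- sketch.
    if name.startswith("layers."):
        return "main", int(name.split(".")[1])
    if name.startswith("context_refiner.") or name.startswith("noise_refiner."):
        return "refiner", -1  # context and noise refiner blocks share one group
    return "other", -1        # embedders, final norm / final layer, etc.

print(classify_param("layers.17.attention.qkv.weight"))  # ('main', 17)
print(classify_param("noise_refiner.0.mlp.w1.weight"))   # ('refiner', -1)
print(classify_param("cap_embedder.1.weight"))           # ('other', -1)
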
@@ -759,7 +765,7 @@ def grad_hook(parameter: torch.Tensor):

                # calculate loss
                huber_c = train_util.get_huber_threshold_if_needed(
-                    args, timesteps, noise_scheduler
+                    args, 1000 - timesteps, noise_scheduler
                )
                loss = train_util.conditional_loss(
                    model_pred.float(), target.float(), args.loss_type, "none", huber_c
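huber_c scales a robust, pseudo-Huber-style loss whose threshold is scheduled by timestep, so passing 1000 - timesteps flips which end of the noise schedule receives the larger threshold, presumably to match the direction of Lumina's flow-matching timesteps. As a generic illustration of such a loss only (not kohya-ss's get_huber_threshold_if_needed or conditional_loss implementation):

import torch

def pseudo_huber(pred, target, c):
    # ~L2 for residuals much smaller than c, ~L1 for residuals much larger (per-sample c).
    c = c.view(-1, *([1] * (pred.dim() - 1)))  # broadcast the per-sample threshold
    return 2 * c * (torch.sqrt((pred - target) ** 2 + c**2) - c)

pred, target = torch.randn(4, 3), torch.randn(4, 3)
c = torch.full((4,), 0.1)
print(pseudo_huber(pred, target, c).shape)  # torch.Size([4, 3]); reduced later ("none" -> mean)
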
