Commit 1fdae85

Commit message: update
1 parent: 6b9fd09

21 files changed: +120 / -107 lines

src/diffusers/models/attention.py

Lines changed: 1 addition & 1 deletion
@@ -449,7 +449,7 @@ def forward(
             norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
         elif self.norm_type == "ada_norm_single":
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+                self.scale_shift_table[None].to(timestep.dtype) + timestep.reshape(batch_size, 6, -1)
             ).chunk(6, dim=1)
             norm_hidden_states = self.norm1(hidden_states)
             norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
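This cast follows from the layerwise upcasting mechanism added in `src/diffusers/models/modeling_utils.py` below: the hooks upcast submodules right before their own `forward`, so a parameter read directly inside the parent block's `forward` (here `scale_shift_table`) is still in the low-memory storage dtype when this line runs. A minimal sketch of the pattern, with shapes and dtypes assumed purely for illustration:

```python
import torch

# Illustrative sketch only; shapes/dtypes are placeholders, not the diffusers module.
storage_dtype = torch.float16  # stand-in for a low-memory storage dtype such as float8_e4m3fn
compute_dtype = torch.float32  # dtype the activations are upcast to for the forward pass

# A parameter kept in the storage dtype and read directly in the parent forward.
scale_shift_table = torch.nn.Parameter(torch.randn(6, 1152, dtype=storage_dtype))
# The timestep embedding arrives already in the compute dtype.
timestep = torch.randn(4, 6 * 1152, dtype=compute_dtype)
batch_size = timestep.shape[0]

# Mirrors the change above: align the table with the activation dtype explicitly
# instead of relying on implicit promotion (float8 storage dtypes generally do not
# support arithmetic, so the addition would fail without the cast).
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None].to(timestep.dtype) + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
```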

src/diffusers/models/autoencoders/autoencoder_asym_kl.py

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
             Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """
 
+    _always_upcast_modules = ["MaskConditionDecoder"]
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+    _always_upcast_modules = ["Decoder"]
 
     @register_to_config
     def __init__(

src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

Lines changed: 1 addition & 0 deletions
@@ -192,6 +192,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["TemporalDecoder"]
 
     @register_to_config
     def __init__(

src/diffusers/models/autoencoders/autoencoder_oobleck.py

Lines changed: 1 addition & 0 deletions
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = False
+    _always_upcast_modules = ["OobleckEncoder", "OobleckDecoder"]
 
     @register_to_config
    def __init__(

src/diffusers/models/autoencoders/consistency_decoder_vae.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ def decode(
             Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
 
         """
-        z = (z * self.config.scaling_factor - self.means) / self.stds
+        z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)
 
         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
         z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)

src/diffusers/models/autoencoders/vq_model.py

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin):
             Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
     """
 
+    _always_upcast_modules = ["Decoder", "VectorQuantizer"]
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/modeling_utils.py

Lines changed: 40 additions & 7 deletions
@@ -264,28 +264,61 @@ def disable_xformers_memory_efficient_attention(self) -> None:
         self.set_use_memory_efficient_attention_xformers(False)
 
     def enable_layerwise_upcasting(self, upcast_dtype=None):
+        r"""
+        Enable layerwise dynamic upcasting. This allows models to be loaded into the GPU in a low memory dtype, e.g.
+        torch.float8_e4m3fn, but perform inference using a dtype that is supported by the GPU, by upcasting the
+        individual modules in the model to the appropriate dtype right before the forward pass.
+
+        The module is then moved back to the low memory dtype after the forward pass.
+        """
+
         upcast_dtype = upcast_dtype or torch.float32
-        downcast_dtype = self.dtype
+        original_dtype = self.dtype
 
-        def upcast_hook_fn(module):
+        def upcast_dtype_hook_fn(module, *args, **kwargs):
             module = module.to(upcast_dtype)
 
-        def downcast_hook_fn(module):
-            module = module.to(downcast_dtype)
+        def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
+            module = module.to(original_dtype)
 
         def fn_recursive_upcast(module):
+            """In certain cases modules will apply casting internally or reference the dtype of internal blocks.
+
+            e.g.
+
+            ```
+            class MyModel(nn.Module):
+                def forward(self, x):
+                    dtype = next(iter(self.blocks.parameters())).dtype
+                    x = self.blocks(x) + torch.ones(x.size()).to(dtype)
+            ```
+            Layerwise upcasting will not work here, since the internal blocks remain in the low memory dtype until
+            their `forward` method is called. We need to add the upcast hook on the entire module in order for the
+            operation to work.
+
+            The `_always_upcast_modules` class attribute is a list of modules within the model that we must upcast
+            entirely, rather than layerwise.
+
+            """
+            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
+                # Upcast the entire module and exit the recursion
+                module.register_forward_pre_hook(upcast_dtype_hook_fn)
+                module.register_forward_hook(cast_to_original_dtype_hook_fn)
+
+                return
+
             has_children = list(module.children())
             if not has_children:
-                module.register_forward_pre_hook(upcast_hook_fn)
-                module.register_forward_hook(downcast_hook_fn)
+                module.register_forward_pre_hook(upcast_dtype_hook_fn)
+                module.register_forward_hook(cast_to_original_dtype_hook_fn)
 
             for child in module.children():
                 fn_recursive_upcast(child)
 
         for module in self.children():
             fn_recursive_upcast(module)
 
-    def disable_dynamic_upcasting(self):
+    def disable_layerwise_upcasting(self):
         def fn_recursive_upcast(module):
             has_children = list(module.children())
             if not has_children:
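For reference, a hedged usage sketch of the new API. Only `enable_layerwise_upcasting` and `disable_layerwise_upcasting` come from the diff above; the checkpoint id is a placeholder, and loading weights directly in `torch.float8_e4m3fn` is assumed to be supported by the surrounding setup:

```python
import torch
from diffusers import AutoencoderKL

# Hypothetical usage sketch: keep the weights in a low-memory storage dtype and let the
# hooks upcast each leaf module (or any class listed in `_always_upcast_modules`) to
# float32 right before its forward pass, then cast it back afterwards.
vae = AutoencoderKL.from_pretrained(
    "some/vae-checkpoint",            # placeholder model id, not from this commit
    torch_dtype=torch.float8_e4m3fn,  # assumes float8 loading works in this environment
).to("cuda")

vae.enable_layerwise_upcasting(upcast_dtype=torch.float32)

latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float32)
with torch.no_grad():
    image = vae.decode(latents).sample

# Assumed to undo the per-module hooks; its body is truncated in this diff.
vae.disable_layerwise_upcasting()
```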

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 7 additions & 2 deletions
@@ -259,6 +259,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["AuraFlowPatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -440,11 +441,15 @@ def forward(
 
         # Apply patch embedding, timestep embedding, and project the caption embeddings.
         hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
-        temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
+        temb = self.time_step_embed(timestep).to(dtype=hidden_states.dtype)
         temb = self.time_step_proj(temb)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
         encoder_hidden_states = torch.cat(
-            [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
+            [
+                self.register_tokens.to(encoder_hidden_states.dtype).repeat(encoder_hidden_states.size(0), 1, 1),
+                encoder_hidden_states,
+            ],
+            dim=1,
         )
 
         # MMDiT blocks.

src/diffusers/models/transformers/dit_transformer_2d.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["PatchEmbed"]
 
     @register_to_config
     def __init__(
