Commit f42fe8c

Merge branch 'main' into redux
2 parents: 971b376 + 8a450c3

17 files changed: +1874 −1277 lines

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -433,7 +433,7 @@ def create_forward(*inputs):
                     hidden_states,
                     temb,
                     zq,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -531,7 +531,7 @@ def create_forward(*inputs):
             return create_forward

         hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-            create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+            create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
         )
     else:
         hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -649,7 +649,7 @@ def create_forward(*inputs):
                     hidden_states,
                     temb,
                     zq,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -789,7 +789,7 @@ def custom_forward(*inputs):
                     hidden_states,
                     temb,
                     None,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )

             # 2. Mid
@@ -798,14 +798,14 @@ def custom_forward(*inputs):
                 hidden_states,
                 temb,
                 None,
-                conv_cache=conv_cache.get("mid_block"),
+                conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
                 hidden_states, new_conv_cache[conv_cache_key] = down_block(
-                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                    hidden_states, temb, None, conv_cache.get(conv_cache_key)
                 )

             # 2. Mid
@@ -953,7 +953,7 @@ def custom_forward(*inputs):
                 hidden_states,
                 temb,
                 sample,
-                conv_cache=conv_cache.get("mid_block"),
+                conv_cache.get("mid_block"),
             )

             # 2. Up
@@ -964,7 +964,7 @@ def custom_forward(*inputs):
                     hidden_states,
                     temb,
                     sample,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
```
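
All of the hunks above make the same change: `conv_cache` is passed positionally instead of as a keyword at the checkpointed call sites. Below is a minimal sketch of the pattern with a toy `resnet` stand-in (names and shapes are assumptions, not the library's code); the reentrant form of `torch.utils.checkpoint.checkpoint` only forwards positional inputs through the `custom_forward(*inputs)` wrapper, so a keyword argument would not reach the wrapped module:

```python
import torch


def resnet(hidden_states, temb, zq, conv_cache=None):
    # Toy stand-in for a CogVideoX resnet block.
    return hidden_states * 2


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


hidden_states = torch.randn(1, 4, requires_grad=True)
conv_cache = {}

# conv_cache.get(...) is passed positionally; with use_reentrant=True,
# a conv_cache=... keyword would not be forwarded to custom_forward.
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    None,  # temb
    None,  # zq
    conv_cache.get("resnet_0"),
    use_reentrant=True,
)
out.sum().backward()
```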
```diff
@@ -1476,7 +1476,7 @@ def forward(
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z)
+        dec = self.decode(z).sample
         if not return_dict:
             return (dec,)
-        return dec
+        return DecoderOutput(sample=dec)
```
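
This last hunk fixes the return convention: `decode()` returns a `DecoderOutput`, so `forward()` now unwraps `.sample` and re-wraps the tensor. A short sketch of that convention, assuming a toy tensor in place of the real decode and the `diffusers.models.autoencoders.vae` import path:

```python
import torch
from diffusers.models.autoencoders.vae import DecoderOutput

dec = torch.randn(1, 3, 32, 32)  # stand-in for self.decode(z).sample
return_dict = True

# Tuple when return_dict=False, DecoderOutput otherwise: the convention
# the fixed forward() now follows.
result = (dec,) if not return_dict else DecoderOutput(sample=dec)
assert torch.equal(result.sample, dec)
```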

src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -229,14 +229,6 @@ def __init__(

         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

-        sample_size = (
-            self.config.sample_size[0]
-            if isinstance(self.config.sample_size, (list, tuple))
-            else self.config.sample_size
-        )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
-        self.tile_overlap_factor = 0.25
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, TemporalDecoder)):
             module.gradient_checkpointing = value
```

src/diffusers/models/autoencoders/autoencoder_tiny.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -310,7 +310,9 @@ def decode(
         self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
     ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
-            output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
+            output = [
+                self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
+            ]
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
```
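
The bug here is that the list comprehension decoded the full batch `x` once per slice; the fix decodes each `x_slice`. A hedged sketch with a toy decoder (a plain `Conv2d`, not the real `AutoencoderTiny` decoder) showing the corrected sliced path:

```python
import torch

decoder = torch.nn.Conv2d(4, 3, kernel_size=1)  # toy stand-in for the decoder
x = torch.randn(3, 4, 8, 8)  # batch of 3 latents

# Corrected path (use_slicing=True, use_tiling=False): decode slice by slice.
output = torch.cat([decoder(x_slice) for x_slice in x.split(1)])
assert output.shape == (3, 3, 8, 8)  # the old decoder(x) yielded 9 rows here
```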
```diff
@@ -341,7 +343,7 @@ def forward(
         # as if we were loading the latents from an RGBA uint8 image.
         unscaled_enc = self.unscale_latents(scaled_enc / 255.0)

-        dec = self.decode(unscaled_enc)
+        dec = self.decode(unscaled_enc).sample

         if not return_dict:
             return (dec,)
```

src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@
 from ...models.embeddings import get_1d_rotary_pos_embed
 from ...schedulers import EDMDPMSolverMultistepScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -34,6 +35,13 @@
 from .modeling_stable_audio import StableAudioProjectionModel


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
@@ -726,6 +734,9 @@ def __call__(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post-processing
         if not output_type == "latent":
             audio = self.vae.decode(latents).sample
```
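
`xm.mark_step()` at the end of each denoising iteration tells the lazily evaluating XLA backend to compile and execute the graph accumulated for that step; on non-XLA setups the guard keeps it a no-op. A minimal sketch of the guard-plus-loop pattern, with the loop body reduced to a placeholder:

```python
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

for i, t in enumerate(range(10)):  # stand-in for the scheduler timesteps
    ...  # noise prediction, scheduler.step(), and callbacks would go here

    if XLA_AVAILABLE:
        xm.mark_step()  # cut and execute the traced graph for this step
```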
