
Commit 275041d ("update"), 1 parent: ccc1b36


src/diffusers/pipelines/mochi/pipeline_mochi.py (1 file changed: 35 additions, 6 deletions)
@@ -139,6 +139,7 @@ def retrieve_timesteps(
 
 class MochiPipeline(
     DiffusionPipeline,
+    TextualInversionLoaderMixin
 ):
     r"""
     The Flux pipeline for text-to-image generation.
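
Adding TextualInversionLoaderMixin to the class bases exposes diffusers' standard load_textual_inversion / maybe_convert_prompt entry points on this pipeline. A minimal usage sketch, assuming the mixin works with Mochi's T5 text encoder the same way it does for other pipelines; the model id, embedding path, and token below are placeholders:

import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)  # assumed model id
# load_textual_inversion is provided by TextualInversionLoaderMixin;
# the embedding file and "<my-token>" are hypothetical placeholders.
pipe.load_textual_inversion("path/to/learned_embeds.safetensors", token="<my-token>")
prompt = "a video of <my-token> splashing in a pond"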
@@ -187,14 +188,17 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
-        )
+        # TODO: determine these scaling factors from model parameters
+        self.vae_spatial_scale_factor = 8
+        self.vae_temporal_scale_factor = 6
+        self.patch_size = 2
+
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_height = 64
+        self.default_width = 64
 
     def _get_t5_prompt_embeds(
         self,
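
The removed vae_scale_factor formula is replaced by hardcoded spatial and temporal factors plus the transformer patch size (with a TODO to derive them from the model configs). A small sketch of how these factors translate video dimensions into latent dimensions and token counts; the 480x848, 19-frame input and the temporal rounding rule are assumptions for illustration, not documented defaults:

vae_spatial_scale_factor = 8
vae_temporal_scale_factor = 6
patch_size = 2

height, width, num_frames = 480, 848, 19  # assumed example input
latent_height = height // vae_spatial_scale_factor                 # 60
latent_width = width // vae_spatial_scale_factor                   # 106
latent_frames = (num_frames - 1) // vae_temporal_scale_factor + 1  # 4, assumed rounding

# visual tokens after patch_size x patch_size patchification
num_visual_tokens = latent_frames * latent_height * latent_width // patch_size**2  # 6360

Note that the unchanged context line still constructs VaeImageProcessor from self.vae_scale_factor, which is no longer assigned after this change; a follow-up would need to pass self.vae_spatial_scale_factor (or equivalent) there instead.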
@@ -235,7 +239,7 @@ def _get_t5_prompt_embeds(
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0]
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
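
On a transformers encoder output, indexing with [0] and reading .last_hidden_state return the same tensor; the rename only makes the intent explicit. A self-contained check, using google/t5-v1_1-small as a stand-in for the pipeline's actual text encoder:

import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")
encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

inputs = tokenizer("a cat playing piano", padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
    out = encoder(inputs.input_ids, attention_mask=inputs.attention_mask, output_hidden_states=False)

# ModelOutput supports both access styles; they alias the same tensor.
assert out[0] is out.last_hidden_state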
@@ -246,7 +250,32 @@ def _get_t5_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        return prompt_embeds
+        return prompt_embeds, prompt_attention_mask
+
+    def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim):
+        N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2)
+
+        # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
+        assert N > 0 and len(attention_mask) == 1
+        attention_mask = attention_mask[0]
+
+        mask = F.pad(attention_mask, (N, 0), value=True)  # (B, N + L)
+        seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
+        valid_token_indices = torch.nonzero(
+            mask.flatten(), as_tuple=False
+        ).flatten()  # up to (B * (N + L),)
+
+        assert valid_token_indices.size(0) >= attention_mask.size(0) * N  # At least (B * N,)
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+
+        return {
+            "cu_seqlens_kv": cu_seqlens,
+            "max_seqlen_in_batch_kv": max_seqlen_in_batch,
+            "valid_token_indices_kv": valid_token_indices,
+        }
 
     def encode_prompt(
         self,
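
What the new _pack_indices helper computes is easiest to see on a toy mask: N visual tokens are prepended as always-valid, then per-sample valid lengths, their cumulative offsets, and the flat indices of valid positions are derived. This is the layout consumed by packed (varlen) attention kernels such as flash-attn's flash_attn_varlen_func, though that intended consumer is an assumption here. A standalone sketch of the same computation, with shapes chosen arbitrarily:

import torch
import torch.nn.functional as F

patch_size = 2
latent_frames, latent_height, latent_width = 1, 4, 4
N = latent_frames * latent_height * latent_width // patch_size**2  # 4 visual tokens

# batch of 2 prompts with text length L = 3; the second prompt has one pad token
text_mask = torch.tensor([[True, True, True],
                          [True, True, False]])

mask = F.pad(text_mask, (N, 0), value=True)             # (B, N + L): visual tokens always valid
seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # tensor([7, 6])
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

print(cu_seqlens)           # tensor([ 0,  7, 13], dtype=torch.int32)
print(valid_token_indices)  # 0..6 and 7..12; position 13 (the pad token) is dropped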
