@@ -103,28 +103,28 @@ def calculate_shift(
103103 return mu
104104
105105
106- # Copied from diffusers.pipelines.flux.pipeline_flux._pack_latents
106+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline. _pack_latents
107107def _pack_latents (latents , batch_size , num_channels_latents , height , width ):
108- latents = latents .view (batch_size , num_channels_latents , height // 2 , 2 , width // 2 , 2 )
109- latents = latents .permute (0 , 2 , 4 , 1 , 3 , 5 )
110- latents = latents .reshape (batch_size , (height // 2 ) * (width // 2 ), num_channels_latents * 4 )
108+ latents = latents .view (batch_size , num_channels_latents , height // 2 , 2 , width // 2 , 2 )
109+ latents = latents .permute (0 , 2 , 4 , 1 , 3 , 5 )
110+ latents = latents .reshape (batch_size , (height // 2 ) * (width // 2 ), num_channels_latents * 4 )
111111
112- return latents
112+ return latents
113113
114114
115- # Copied from diffusers.pipelines.flux.pipeline_flux._prepare_latent_image_ids
115+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline. _prepare_latent_image_ids
116116def _prepare_latent_image_ids (batch_size , height , width , device , dtype ):
117- latent_image_ids = torch .zeros (height , width , 3 )
118- latent_image_ids [..., 1 ] = latent_image_ids [..., 1 ] + torch .arange (height )[:, None ]
119- latent_image_ids [..., 2 ] = latent_image_ids [..., 2 ] + torch .arange (width )[None , :]
117+ latent_image_ids = torch .zeros (height , width , 3 )
118+ latent_image_ids [..., 1 ] = latent_image_ids [..., 1 ] + torch .arange (height )[:, None ]
119+ latent_image_ids [..., 2 ] = latent_image_ids [..., 2 ] + torch .arange (width )[None , :]
120120
121- latent_image_id_height , latent_image_id_width , latent_image_id_channels = latent_image_ids .shape
121+ latent_image_id_height , latent_image_id_width , latent_image_id_channels = latent_image_ids .shape
122122
123- latent_image_ids = latent_image_ids .reshape (
124- latent_image_id_height * latent_image_id_width , latent_image_id_channels
125- )
123+ latent_image_ids = latent_image_ids .reshape (
124+ latent_image_id_height * latent_image_id_width , latent_image_id_channels
125+ )
126126
127- return latent_image_ids .to (device = device , dtype = dtype )
127+ return latent_image_ids .to (device = device , dtype = dtype )
128128
129129
130130class FluxInputStep (PipelineBlock ):
0 commit comments