@@ -177,7 +177,6 @@ def __init__(
177177 def _get_glm_embeds (
178178 self ,
179179 prompt : Union [str , List [str ]] = None ,
180- num_images_per_prompt : int = 1 ,
181180 max_sequence_length : int = 1024 ,
182181 device : Optional [torch .device ] = None ,
183182 dtype : Optional [torch .dtype ] = None ,
@@ -186,7 +185,6 @@ def _get_glm_embeds(
186185 dtype = dtype or self .text_encoder .dtype
187186
188187 prompt = [prompt ] if isinstance (prompt , str ) else prompt
189- batch_size = len (prompt )
190188
191189 text_inputs = self .tokenizer (
192190 prompt ,
@@ -219,9 +217,6 @@ def _get_glm_embeds(
219217 ).hidden_states [- 2 ]
220218
221219 prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
222- _ , seq_len , _ = prompt_embeds .shape
223- prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
224- prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
225220 return prompt_embeds
226221
227222 def encode_prompt (
@@ -273,7 +268,11 @@ def encode_prompt(
273268 batch_size = prompt_embeds .shape [0 ]
274269
275270 if prompt_embeds is None :
276- prompt_embeds = self ._get_glm_embeds (prompt , num_images_per_prompt , max_sequence_length , device , dtype )
271+ prompt_embeds = self ._get_glm_embeds (prompt , max_sequence_length , device , dtype )
272+
273+ seq_len = prompt_embeds .size (1 )
274+ prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
275+ prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
277276
278277 if do_classifier_free_guidance and negative_prompt_embeds is None :
279278 negative_prompt = negative_prompt or ""
@@ -291,9 +290,11 @@ def encode_prompt(
291290 " the batch size of `prompt`."
292291 )
293292
294- negative_prompt_embeds = self ._get_glm_embeds (
295- negative_prompt , num_images_per_prompt , max_sequence_length , device , dtype
296- )
293+ negative_prompt_embeds = self ._get_glm_embeds (negative_prompt , max_sequence_length , device , dtype )
294+
295+ seq_len = negative_prompt_embeds .size (1 )
296+ negative_prompt_embeds = negative_prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
297+ negative_prompt_embeds = negative_prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
297298
298299 return prompt_embeds , negative_prompt_embeds
299300
@@ -575,7 +576,7 @@ def __call__(
575576 if timesteps is None
576577 else np .array (timesteps )
577578 )
578- timesteps = timesteps .astype (np .int64 )
579+ timesteps = timesteps .astype (np .float32 )
579580 sigmas = timesteps / self .scheduler .config .num_train_timesteps if sigmas is None else sigmas
580581 mu = calculate_shift (
581582 image_seq_len ,
@@ -585,6 +586,7 @@ def __call__(
585586 )
586587 _ , num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , sigmas = sigmas , mu = mu )
587588 timesteps = torch .from_numpy (timesteps ).to (device )
589+ self ._num_timesteps = len (timesteps )
588590
589591 # Denoising loop
590592 transformer_dtype = self .transformer .dtype
0 commit comments