Skip to content

Commit 4003b9c

Browse files
committed
apply suggestions from review
1 parent da420fb commit 4003b9c

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def __init__(
177177
def _get_glm_embeds(
178178
self,
179179
prompt: Union[str, List[str]] = None,
180-
num_images_per_prompt: int = 1,
181180
max_sequence_length: int = 1024,
182181
device: Optional[torch.device] = None,
183182
dtype: Optional[torch.dtype] = None,
@@ -186,7 +185,6 @@ def _get_glm_embeds(
186185
dtype = dtype or self.text_encoder.dtype
187186

188187
prompt = [prompt] if isinstance(prompt, str) else prompt
189-
batch_size = len(prompt)
190188

191189
text_inputs = self.tokenizer(
192190
prompt,
@@ -219,9 +217,6 @@ def _get_glm_embeds(
219217
).hidden_states[-2]
220218

221219
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
222-
_, seq_len, _ = prompt_embeds.shape
223-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
224-
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
225220
return prompt_embeds
226221

227222
def encode_prompt(
@@ -273,7 +268,11 @@ def encode_prompt(
273268
batch_size = prompt_embeds.shape[0]
274269

275270
if prompt_embeds is None:
276-
prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
271+
prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
272+
273+
seq_len = prompt_embeds.size(1)
274+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
275+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
277276

278277
if do_classifier_free_guidance and negative_prompt_embeds is None:
279278
negative_prompt = negative_prompt or ""
@@ -291,9 +290,11 @@ def encode_prompt(
291290
" the batch size of `prompt`."
292291
)
293292

294-
negative_prompt_embeds = self._get_glm_embeds(
295-
negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
296-
)
293+
negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
294+
295+
seq_len = negative_prompt_embeds.size(1)
296+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
297+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
297298

298299
return prompt_embeds, negative_prompt_embeds
299300

@@ -575,7 +576,7 @@ def __call__(
575576
if timesteps is None
576577
else np.array(timesteps)
577578
)
578-
timesteps = timesteps.astype(np.int64)
579+
timesteps = timesteps.astype(np.float32)
579580
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
580581
mu = calculate_shift(
581582
image_seq_len,
@@ -585,6 +586,7 @@ def __call__(
585586
)
586587
_, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
587588
timesteps = torch.from_numpy(timesteps).to(device)
589+
self._num_timesteps = len(timesteps)
588590

589591
# Denoising loop
590592
transformer_dtype = self.transformer.dtype

0 commit comments

Comments (0)