We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d8c617c · commit 807c8ac — Copy full SHA for 807c8ac
src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -213,9 +213,7 @@ def _get_glm_embeds(
213
device=text_input_ids.device,
214
)
215
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
216
- prompt_embeds = self.text_encoder(
217
- text_input_ids.to(self.text_encoder.device), output_hidden_states=True
218
- ).hidden_states[-2]
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
219
220
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
221
return prompt_embeds
0 commit comments