Skip to content

Commit 3e3387e

Browse files
committed
make style
1 parent 88abb39 commit 3e3387e

File tree

4 files changed

+26
-31
lines changed

4 files changed

+26
-31
lines changed

scripts/convert_cogview4_to_diffusers_megatron.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
162162
Returns:
163163
dict: The converted state dictionary compatible with Diffusers.
164164
"""
165-
ckpt = torch.load(ckpt_path, map_location="cpu",weights_only=False)
165+
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
166166
mega = ckpt["model"]
167167

168168
new_state_dict = {}
@@ -260,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
260260
Returns:
261261
dict: The converted VAE state dictionary compatible with Diffusers.
262262
"""
263-
original_state_dict = torch.load(ckpt_path, map_location="cpu",weights_only=False)["state_dict"]
263+
original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
264264
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
265265

266266

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,8 @@
345345
"CogVideoXPipeline",
346346
"CogVideoXVideoToVideoPipeline",
347347
"CogView3PlusPipeline",
348-
"CogView4Pipeline",
349348
"CogView4ControlPipeline",
349+
"CogView4Pipeline",
350350
"ConsisIDPipeline",
351351
"CycleDiffusionPipeline",
352352
"EasyAnimateControlPipeline",

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,12 @@
5050
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
5151
... )
5252
>>> prompt = "A bird in space"
53-
>>> image = pipe(
54-
... prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5)
55-
... ).images[0]
53+
>>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0]
5654
>>> image.save("cogview4-control.png")
5755
```
5856
"""
5957

58+
6059
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift
6160
def calculate_shift(
6261
image_seq_len,
@@ -101,19 +100,10 @@ def retrieve_timesteps(
101100
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
102101
second element is the number of inference steps.
103102
"""
104-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
105-
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
106-
107103
if timesteps is not None and sigmas is not None:
108-
if not accepts_timesteps and not accepts_sigmas:
109-
raise ValueError(
110-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
111-
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
112-
)
113-
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
114-
timesteps = scheduler.timesteps
115-
num_inference_steps = len(timesteps)
116-
elif timesteps is not None and sigmas is None:
104+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
105+
if timesteps is not None:
106+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
117107
if not accepts_timesteps:
118108
raise ValueError(
119109
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -122,8 +112,9 @@ def retrieve_timesteps(
122112
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
123113
timesteps = scheduler.timesteps
124114
num_inference_steps = len(timesteps)
125-
elif timesteps is None and sigmas is not None:
126-
if not accepts_sigmas:
115+
elif sigmas is not None:
116+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
117+
if not accept_sigmas:
127118
raise ValueError(
128119
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129120
f" sigmas schedules. Please check whether you are using the correct scheduler."
@@ -182,7 +173,6 @@ def __init__(
182173
def _get_glm_embeds(
183174
self,
184175
prompt: Union[str, List[str]] = None,
185-
num_images_per_prompt: int = 1,
186176
max_sequence_length: int = 1024,
187177
device: Optional[torch.device] = None,
188178
dtype: Optional[torch.dtype] = None,
@@ -191,7 +181,6 @@ def _get_glm_embeds(
191181
dtype = dtype or self.text_encoder.dtype
192182

193183
prompt = [prompt] if isinstance(prompt, str) else prompt
194-
batch_size = len(prompt)
195184

196185
text_inputs = self.tokenizer(
197186
prompt,
@@ -224,9 +213,6 @@ def _get_glm_embeds(
224213
).hidden_states[-2]
225214

226215
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
227-
_, seq_len, _ = prompt_embeds.shape
228-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
229-
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
230216
return prompt_embeds
231217

232218
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
@@ -277,8 +263,13 @@ def encode_prompt(
277263
batch_size = len(prompt)
278264
else:
279265
batch_size = prompt_embeds.shape[0]
266+
280267
if prompt_embeds is None:
281-
prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
268+
prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
269+
270+
seq_len = prompt_embeds.size(1)
271+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
272+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
282273

283274
if do_classifier_free_guidance and negative_prompt_embeds is None:
284275
negative_prompt = negative_prompt or ""
@@ -296,9 +287,11 @@ def encode_prompt(
296287
" the batch size of `prompt`."
297288
)
298289

299-
negative_prompt_embeds = self._get_glm_embeds(
300-
negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
301-
)
290+
negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
291+
292+
seq_len = negative_prompt_embeds.size(1)
293+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
294+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
302295

303296
return prompt_embeds, negative_prompt_embeds
304297

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def from_pretrained(cls, *args, **kwargs):
362362
requires_backends(cls, ["torch", "transformers"])
363363

364364

365-
class CogView4Pipeline(metaclass=DummyObject):
365+
class CogView4ControlPipeline(metaclass=DummyObject):
366366
_backends = ["torch", "transformers"]
367367

368368
def __init__(self, *args, **kwargs):
@@ -376,7 +376,8 @@ def from_config(cls, *args, **kwargs):
376376
def from_pretrained(cls, *args, **kwargs):
377377
requires_backends(cls, ["torch", "transformers"])
378378

379-
class CogView4ControlPipeline(metaclass=DummyObject):
379+
380+
class CogView4Pipeline(metaclass=DummyObject):
380381
_backends = ["torch", "transformers"]
381382

382383
def __init__(self, *args, **kwargs):
@@ -390,6 +391,7 @@ def from_config(cls, *args, **kwargs):
390391
def from_pretrained(cls, *args, **kwargs):
391392
requires_backends(cls, ["torch", "transformers"])
392393

394+
393395
class ConsisIDPipeline(metaclass=DummyObject):
394396
_backends = ["torch", "transformers"]
395397

0 commit comments

Comments (0)