
Commit 711dded

make style
1 parent f91cfcf commit 711dded

File tree

5 files changed: +32 -21 lines changed

src/diffusers/models/transformers/transformer_omnigen.py

Lines changed: 15 additions & 8 deletions
@@ -16,9 +16,9 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
@@ -91,7 +91,11 @@ def __init__(
         self.pos_embed_max_size = pos_embed_max_size

         pos_embed = get_2d_sincos_pos_embed(
-            embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale, output_type="pt"
+            embed_dim,
+            self.pos_embed_max_size,
+            base_size=base_size,
+            interpolation_scale=self.interpolation_scale,
+            output_type="pt",
         )
         self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True)

@@ -227,7 +231,7 @@ def apply_rotary_emb(
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
     """
-
+
     cos, sin = freqs_cis  # [S, D]
     if len(cos.shape) == 2:
         cos = cos[None, None]
@@ -241,10 +245,10 @@ def apply_rotary_emb(
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
     x_rotated = torch.cat((-x2, x1), dim=-1)
-
+
     out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
     return out
-
+

 class OmniGenAttnProcessor2_0:
     r"""
@@ -264,7 +268,6 @@ def __call__(
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         batch_size, sequence_length, _ = hidden_states.shape

         # Get Query-Key-Value Pair
@@ -674,9 +677,13 @@ def forward(
         image_rotary_emb = self.rotary_emb(hidden_states, position_ids)
         for decoder_layer in self.layers:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(decoder_layer, hidden_states, attention_mask, image_rotary_emb)
+                hidden_states = self._gradient_checkpointing_func(
+                    decoder_layer, hidden_states, attention_mask, image_rotary_emb
+                )
             else:
-                hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)
+                hidden_states = decoder_layer(
+                    hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb
+                )

         hidden_states = self.norm(hidden_states)
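
The apply_rotary_emb hunks above only touch whitespace, but the code they pass through is the standard rotate-half form of rotary position embeddings. Below is a minimal, self-contained sketch of that computation for reference; the function name and tensor shapes are illustrative assumptions, not values from the model config.

import torch


def rotate_half_apply(x, cos, sin):
    # Split the head dimension in half and rotate the halves: (x1, x2) -> (-x2, x1),
    # then combine as x * cos + rotate_half(x) * sin, mirroring the hunk above.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    x_rotated = torch.cat((-x2, x1), dim=-1)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)


# Illustrative shapes only: batch=1, heads=2, seq=4, head_dim=8.
q = torch.randn(1, 2, 4, 8)
cos = torch.randn(4, 8)[None, None]  # broadcast over batch and heads, like cos[None, None] above
sin = torch.randn(4, 8)[None, None]
print(rotate_half_apply(q, cos, sin).shape)  # torch.Size([1, 2, 4, 8])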

src/diffusers/pipelines/consisid/pipeline_consisid.py

Lines changed: 8 additions & 3 deletions
@@ -48,9 +48,14 @@
         >>> from huggingface_hub import snapshot_download

         >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
-        >>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
-        ...     prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
-        ... )
+        >>> (
+        ...     face_helper_1,
+        ...     face_helper_2,
+        ...     face_clip_model,
+        ...     face_main_model,
+        ...     eva_transform_mean,
+        ...     eva_transform_std,
+        ... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
         >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")

src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Lines changed: 3 additions & 4 deletions
@@ -16,7 +16,6 @@
 from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
-import PIL
 import torch
 from transformers import LlamaTokenizer

@@ -223,9 +222,9 @@ def check_inputs(
         if use_input_image_size_as_output:
             if input_images is None or input_images[0] is None:
                 raise ValueError(
-                    f"`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False."
-                )
-
+                    "`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False."
+                )
+
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
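
For context on the string change in check_inputs: the f prefix only does work when the string contains {} placeholders, so the lint pass behind make style drops it from this placeholder-free message. A tiny illustration of the rule; the names here are arbitrary.

pipeline_name = "OmniGen"
plain_message = "no values are interpolated here, so no f prefix is needed"
formatted_message = f"pipeline: {pipeline_name}"  # the prefix matters only when braces interpolate values
assert "OmniGen" in formatted_message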

src/diffusers/pipelines/omnigen/processor_omnigen.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def __init__(self, text_tokenizer, max_image_size: int = 1024):
         )

         self.collator = OmniGenCollator()
-
+
     def reset_max_image_size(self, max_image_size):
         self.max_image_size = max_image_size
         self.image_transform = transforms.Compose(

tests/models/transformers/test_models_transformer_omnigen.py

Lines changed: 5 additions & 5 deletions
@@ -25,6 +25,7 @@

 enable_full_determinism()

+
 class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = OmniGenTransformer2DModel
     main_input_name = "hidden_states"
@@ -42,11 +43,11 @@ def dummy_input(self):
         timestep = torch.rand(size=(batch_size,), dtype=hidden_states.dtype).to(torch_device)
         input_ids = torch.randint(0, 10, (batch_size, sequence_length)).to(torch_device)
         input_img_latents = [torch.randn((1, num_channels, height, width)).to(torch_device)]
-        input_image_sizes = {0: [[0, 0+height*width//2//2]]}
+        input_image_sizes = {0: [[0, 0 + height * width // 2 // 2]]}

-        attn_seq_length = sequence_length + 1 + height*width//2//2
+        attn_seq_length = sequence_length + 1 + height * width // 2 // 2
         attention_mask = torch.ones((batch_size, attn_seq_length, attn_seq_length)).to(torch_device)
-        position_ids = torch.LongTensor([list(range(attn_seq_length))]*batch_size).to(torch_device)
+        position_ids = torch.LongTensor([list(range(attn_seq_length))] * batch_size).to(torch_device)

         return {
             "hidden_states": hidden_states,
@@ -77,12 +78,11 @@ def prepare_init_args_and_inputs_for_common(self):
             "vocab_size": 100,
             "in_channels": 4,
             "time_step_dim": 4,
-            "rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))}
+            "rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"OmniGenTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
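
The reformatted arithmetic in dummy_input above builds one attention sequence out of the text tokens, one extra token, and the flattened image-latent tokens. A quick worked example with hypothetical sizes; the test's real height, width, and sequence_length are not shown in this diff, and the role of the +1 token is an assumption.

# Hypothetical values for illustration only.
height, width = 8, 8
sequence_length = 7

# height * width // 2 // 2 == (height // 2) * (width // 2): one token per 2x2 latent patch.
image_tokens = height * width // 2 // 2
attn_seq_length = sequence_length + 1 + image_tokens  # text tokens + 1 extra token (e.g. timestep) + image tokens

input_image_sizes = {0: [[0, 0 + image_tokens]]}  # image 0 spans image_tokens positions, as in the hunk above
print(image_tokens, attn_seq_length)  # 16 24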
