Skip to content

Commit c58267e

Browse files
committed
style and quality pass
1 parent 7808ee0 commit c58267e

File tree

4 files changed

+20
-38
lines changed

4 files changed

+20
-38
lines changed

src/diffusers/models/transformers/transformer_bria.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import numpy as np
44
import torch
55
import torch.nn as nn
6+
from packaging import version
67

8+
from diffusers import __version__ as diffusers_version
79
from diffusers.configuration_utils import ConfigMixin, register_to_config
810
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
911
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
@@ -13,9 +15,6 @@
1315
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
1416
from diffusers.pipelines.bria.bria_utils import FluxPosEmbed as EmbedND
1517
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
16-
from diffusers import __version__ as diffusers_version
17-
from packaging import version
18-
1918

2019

2120
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -246,7 +245,7 @@ def __init__(
246245
# def _set_gradient_checkpointing(self, module, enable=False):
247246
# if hasattr(module, "gradient_checkpointing"):
248247
# module.gradient_checkpointing = enable
249-
248+
250249
def forward(
251250
self,
252251
hidden_states: torch.Tensor,
@@ -324,14 +323,13 @@ def forward(
324323
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
325324

326325
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
327-
326+
328327
if len(txt_ids.shape) == 3:
329328
txt_ids = txt_ids[0]
330329

331330
if len(img_ids.shape) == 3:
332331
img_ids = img_ids[0]
333332

334-
335333
ids = torch.cat((txt_ids, img_ids), dim=0)
336334
image_rotary_emb = self.pos_embed(ids)
337335

@@ -408,9 +406,9 @@ def custom_forward(*inputs):
408406
else:
409407
if version.parse(diffusers_version) < version.parse("0.35.0.dev0"):
410408
hidden_states = block(
411-
hidden_states=hidden_states,
412-
temb=temb,
413-
image_rotary_emb=image_rotary_emb,
409+
hidden_states=hidden_states,
410+
temb=temb,
411+
image_rotary_emb=image_rotary_emb,
414412
)
415413
else:
416414
encoder_hidden_states, hidden_states = block(

src/diffusers/pipelines/bria/pipeline_bria.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
T5TokenizerFast,
1010
)
1111

12-
import diffusers
1312
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
1413
from diffusers.image_processor import VaeImageProcessor
1514
from diffusers.loaders import FluxLoraLoaderMixin
@@ -83,6 +82,7 @@ class BriaPipeline(FluxPipeline):
8382
Tokenizer of class
8483
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
8584
"""
85+
8686
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
8787
_optional_components = ["image_encoder", "feature_extractor"]
8888
_callback_tensor_inputs = ["latents", "prompt_embeds"]
@@ -455,8 +455,6 @@ def __call__(
455455
latent_image_ids = latent_image_ids[0]
456456
if len(text_ids.shape) == 3:
457457
text_ids = text_ids[0]
458-
459-
460458

461459
# 6. Denoising loop
462460
with self.progress_bar(total=num_inference_steps) as progress_bar:

tests/models/transformers/test_models_transformer_bria.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def output_shape(self):
120120

121121
def prepare_init_args_and_inputs_for_common(self):
122122
init_dict = {
123-
124123
"patch_size": 1,
125124
"in_channels": 4,
126125
"num_layers": 1,
@@ -130,7 +129,6 @@ def prepare_init_args_and_inputs_for_common(self):
130129
"joint_attention_dim": 32,
131130
"pooled_projection_dim": None,
132131
"axes_dims_rope": [0, 4, 4],
133-
134132
}
135133

136134
inputs_dict = self.dummy_input
@@ -163,8 +161,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
163161
torch.allclose(output_1, output_2, atol=1e-5),
164162
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
165163
)
166-
167-
164+
168165
def test_gradient_checkpointing_is_applied(self):
169166
expected_set = {"BriaTransformer2DModel"}
170167
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/pipelines/bria/test_pipeline_bria.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414

1515
import gc
16-
import unittest
1716
import tempfile
17+
import unittest
1818

1919
import numpy as np
2020
import torch
@@ -32,18 +32,17 @@
3232
enable_full_determinism,
3333
nightly,
3434
numpy_cosine_similarity_distance,
35-
require_torch_gpu,
3635
require_accelerator,
36+
require_torch_gpu,
3737
slow,
3838
torch_device,
3939
)
4040

41-
4241
# from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
43-
from tests.pipelines.test_pipelines_common import PipelineTesterMixin,to_np
42+
from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np
4443

45-
enable_full_determinism()
4644

45+
enable_full_determinism()
4746

4847

4948
class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@@ -56,7 +55,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
5655
test_xformers_attention = False
5756
test_layerwise_casting = True
5857
test_group_offloading = True
59-
58+
6059
def get_dummy_components(self):
6160
torch.manual_seed(0)
6261
transformer = BriaTransformer2DModel(
@@ -123,8 +122,10 @@ def get_dummy_inputs(self, device, seed=0):
123122
"output_type": "np",
124123
}
125124
return inputs
125+
126126
def test_encode_prompt_works_in_isolation(self):
127127
pass
128+
128129
def test_bria_different_prompts(self):
129130
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
130131
inputs = self.get_dummy_inputs(torch_device)
@@ -135,9 +136,6 @@ def test_bria_different_prompts(self):
135136
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
136137
assert max_diff > 1e-6
137138

138-
139-
140-
141139
def test_image_output_shape(self):
142140
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
143141
inputs = self.get_dummy_inputs(torch_device)
@@ -194,11 +192,7 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
194192
self.assertLess(
195193
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
196194
)
197-
198-
199-
200195

201-
202196
def test_bria_image_output_shape(self):
203197
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
204198
inputs = self.get_dummy_inputs(torch_device)
@@ -217,9 +211,9 @@ def test_to_dtype(self):
217211
components = self.get_dummy_components()
218212
pipe = self.pipeline_class(**components)
219213
pipe.set_progress_bar_config(disable=None)
220-
214+
221215
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
222-
self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [False,True,True])
216+
self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [False, True, True])
223217

224218
def test_torch_dtype_dict(self):
225219
components = self.get_dummy_components()
@@ -243,8 +237,6 @@ def test_torch_dtype_dict(self):
243237
self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16)
244238
self.assertEqual(loaded_pipe.vae.dtype, torch.float16)
245239

246-
247-
248240

249241
@slow
250242
@require_torch_gpu
@@ -265,9 +257,7 @@ def tearDown(self):
265257
def get_inputs(self, device, seed=0):
266258
generator = torch.Generator(device="cpu").manual_seed(seed)
267259
prompt_embeds = torch.load(
268-
hf_hub_download(
269-
repo_id="diffusers/test-slices", repo_type="dataset", filename="bria_prompt_embeds.pt"
270-
)
260+
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="bria_prompt_embeds.pt")
271261
).to(device)
272262
return {
273263
"prompt_embeds": prompt_embeds,
@@ -324,7 +314,7 @@ def test_bria_inference_bf16(self):
324314
)
325315
max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice)
326316
self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}")
327-
317+
328318
def test_to_dtype(self):
329319
components = self.get_dummy_components()
330320
pipe = self.pipeline_class(**components)
@@ -396,4 +386,3 @@ def test_bria_inference(self):
396386

397387
max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice)
398388
self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}")
399-

0 commit comments

Comments
 (0)