Skip to content

Commit 6e16041

Browse files
committed
style
1 parent eea0436 commit 6e16041

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def test_forward_with_norm_groups(self):
114114
def test_set_attn_processor_for_determinism(self):
115115
return
116116

117-
@unittest.skip("Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'")
117+
@unittest.skip(
118+
"Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'"
119+
)
118120
def test_layerwise_casting_training(self):
119121
return super().test_layerwise_casting_training()
120122

tests/models/test_modeling_common.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,7 +1338,7 @@ def test_variant_sharded_ckpt_right_format(self):
13381338
# Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
13391339
assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
13401340

1341-
def test_layerwise_casting_training(self):
1341+
def test_layerwise_casting_training(self):
13421342
def test_fn(storage_dtype, compute_dtype):
13431343
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
13441344
return
@@ -1360,15 +1360,14 @@ def test_fn(storage_dtype, compute_dtype):
13601360
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
13611361
noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
13621362
loss = torch.nn.functional.mse_loss(output, noise)
1363-
1363+
13641364
loss.backward()
13651365

13661366
test_fn(torch.float16, torch.float32)
13671367
test_fn(torch.float8_e4m3fn, torch.float32)
13681368
test_fn(torch.float8_e5m2, torch.float32)
13691369
test_fn(torch.float8_e4m3fn, torch.bfloat16)
13701370

1371-
13721371
def test_layerwise_casting_inference(self):
13731372
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
13741373

0 commit comments

Comments
 (0)