Commit e97a83e

Remove unrelated changes.
1 parent 0c91c1a commit e97a83e

File tree

1 file changed: +7, -20 lines

tests/lora/utils.py

Lines changed: 7 additions & 20 deletions
@@ -2188,17 +2188,6 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

-    @property
-    def supports_text_encoder_lora(self):
-        return (
-            len(
-                {"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(
-                    self.pipeline_class._lora_loadable_modules
-                )
-            )
-            != 0
-        )
-
     def test_layerwise_casting_inference_denoiser(self):
         from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS

@@ -2251,15 +2240,13 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
         pipe_fp32 = initialize_pipeline(storage_dtype=None)
         pipe_fp32(**inputs, generator=torch.manual_seed(0))[0]

-        # MPS doesn't support float8 yet.
-        if torch_device not in {"mps"}:
-            pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
-            pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]
+        pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
+        pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]

-            pipe_float8_e4m3_bf16 = initialize_pipeline(
-                storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
-            )
-            pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
+        pipe_float8_e4m3_bf16 = initialize_pipeline(
+            storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
+        )
+        pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]

     @require_peft_version_greater("0.14.0")
     def test_layerwise_casting_peft_input_autocast_denoiser(self):
@@ -2284,7 +2271,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
             apply_layerwise_casting,
         )

-        storage_dtype = torch.float8_e4m3fn if not torch_device == "mps" else torch.bfloat16
+        storage_dtype = torch.float8_e4m3fn
         compute_dtype = torch.float32

         def check_module(denoiser):
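
For context, the tests touched here exercise diffusers' layerwise casting hooks: a module's weights are stored in a low-precision dtype (here torch.float8_e4m3fn) and upcast to the compute dtype on the fly during each forward pass. A minimal sketch of the mechanism, using a plain torch.nn.Linear as a hypothetical stand-in for a pipeline's denoiser (the import path matches the one used in these tests; treat the exact call as illustrative, not as this commit's code):

import torch
from diffusers.hooks.layerwise_casting import apply_layerwise_casting

# Hypothetical stand-in for a denoiser module; any supported layer type works.
model = torch.nn.Linear(8, 8)

# Store weights in float8_e4m3fn; upcast to float32 for each forward pass.
apply_layerwise_casting(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)

with torch.no_grad():
    out = model(torch.randn(1, 8))  # computed in float32 despite float8 storage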
