@@ -2188,17 +2188,6 @@ def test_correct_lora_configs_with_different_ranks(self):
21882188 self .assertTrue (not np .allclose (original_output , lora_output_diff_alpha , atol = 1e-3 , rtol = 1e-3 ))
21892189 self .assertTrue (not np .allclose (lora_output_diff_alpha , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
21902190
2191- @property
2192- def supports_text_encoder_lora (self ):
2193- return (
2194- len (
2195- {"text_encoder" , "text_encoder_2" , "text_encoder_3" }.intersection (
2196- self .pipeline_class ._lora_loadable_modules
2197- )
2198- )
2199- != 0
2200- )
2201-
22022191 def test_layerwise_casting_inference_denoiser (self ):
22032192 from diffusers .hooks .layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN , SUPPORTED_PYTORCH_LAYERS
22042193
@@ -2251,15 +2240,13 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
22512240 pipe_fp32 = initialize_pipeline (storage_dtype = None )
22522241 pipe_fp32 (** inputs , generator = torch .manual_seed (0 ))[0 ]
22532242
2254- # MPS doesn't support float8 yet.
2255- if torch_device not in {"mps" }:
2256- pipe_float8_e4m3_fp32 = initialize_pipeline (storage_dtype = torch .float8_e4m3fn , compute_dtype = torch .float32 )
2257- pipe_float8_e4m3_fp32 (** inputs , generator = torch .manual_seed (0 ))[0 ]
2243+ pipe_float8_e4m3_fp32 = initialize_pipeline (storage_dtype = torch .float8_e4m3fn , compute_dtype = torch .float32 )
2244+ pipe_float8_e4m3_fp32 (** inputs , generator = torch .manual_seed (0 ))[0 ]
22582245
2259- pipe_float8_e4m3_bf16 = initialize_pipeline (
2260- storage_dtype = torch .float8_e4m3fn , compute_dtype = torch .bfloat16
2261- )
2262- pipe_float8_e4m3_bf16 (** inputs , generator = torch .manual_seed (0 ))[0 ]
2246+ pipe_float8_e4m3_bf16 = initialize_pipeline (
2247+ storage_dtype = torch .float8_e4m3fn , compute_dtype = torch .bfloat16
2248+ )
2249+ pipe_float8_e4m3_bf16 (** inputs , generator = torch .manual_seed (0 ))[0 ]
22632250
22642251 @require_peft_version_greater ("0.14.0" )
22652252 def test_layerwise_casting_peft_input_autocast_denoiser (self ):
@@ -2284,7 +2271,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
22842271 apply_layerwise_casting ,
22852272 )
22862273
2287- storage_dtype = torch .float8_e4m3fn if not torch_device == "mps" else torch . bfloat16
2274+ storage_dtype = torch .float8_e4m3fn
22882275 compute_dtype = torch .float32
22892276
22902277 def check_module (denoiser ):
0 commit comments