@@ -1528,14 +1528,16 @@ def test_fn(storage_dtype, compute_dtype):
15281528        test_fn (torch .float8_e5m2 , torch .float32 )
15291529        test_fn (torch .float8_e4m3fn , torch .bfloat16 )
15301530
1531+     @torch .no_grad () 
15311532    def  test_layerwise_casting_inference (self ):
15321533        from  diffusers .hooks .layerwise_casting  import  DEFAULT_SKIP_MODULES_PATTERN , SUPPORTED_PYTORCH_LAYERS 
15331534
15341535        torch .manual_seed (0 )
15351536        config , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
1536-         model  =  self .model_class (** config ).eval ()
1537-         model  =  model .to (torch_device )
1538-         base_slice  =  model (** inputs_dict )[0 ].flatten ().detach ().cpu ().numpy ()
1537+         model  =  self .model_class (** config )
1538+         model .eval ()
1539+         model .to (torch_device )
1540+         base_slice  =  model (** inputs_dict )[0 ].detach ().flatten ().cpu ().numpy ()
15391541
15401542        def  check_linear_dtype (module , storage_dtype , compute_dtype ):
15411543            patterns_to_check  =  DEFAULT_SKIP_MODULES_PATTERN 
@@ -1706,10 +1708,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17061708        if  not  self .model_class ._supports_group_offloading :
17071709            pytest .skip ("Model does not support group offloading." )
17081710
1709-         torch .manual_seed (0 )
1710-         init_dict , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
1711-         model  =  self .model_class (** init_dict )
1712- 
17131711        torch .manual_seed (0 )
17141712        init_dict , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
17151713        model  =  self .model_class (** init_dict )
@@ -1725,7 +1723,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17251723                ** additional_kwargs ,
17261724            )
17271725            has_safetensors  =  glob .glob (f"{ tmpdir }  /*.safetensors" )
1728-             assert   has_safetensors , "No safetensors found in the directory." 
1726+             self . assertTrue ( len ( has_safetensors )  >   0 , "No safetensors found in the offload  directory." ) 
17291727            _  =  model (** inputs_dict )[0 ]
17301728
17311729    def  test_auto_model (self , expected_max_diff = 5e-5 ):
0 commit comments