3030from typing import Dict , List , Optional , Tuple , Union
3131
3232import numpy as np
33+ import pytest
3334import requests_mock
3435import safetensors .torch
3536import torch
@@ -938,8 +939,9 @@ def recursive_check(tuple_object, dict_object):
938939
939940 @require_torch_accelerator_with_training
940941 def test_enable_disable_gradient_checkpointing (self ):
942+ # Skip test if model does not support gradient checkpointing
941943 if not self .model_class ._supports_gradient_checkpointing :
942- return # Skip test if model does not support gradient checkpointing
944+ pytest . skip ( "Gradient checkpointing is not supported." )
943945
944946 init_dict , _ = self .prepare_init_args_and_inputs_for_common ()
945947
@@ -957,8 +959,9 @@ def test_enable_disable_gradient_checkpointing(self):
957959
958960 @require_torch_accelerator_with_training
959961 def test_effective_gradient_checkpointing (self , loss_tolerance = 1e-5 , param_grad_tol = 5e-5 , skip : set [str ] = {}):
962+ # Skip test if model does not support gradient checkpointing
960963 if not self .model_class ._supports_gradient_checkpointing :
961- return # Skip test if model does not support gradient checkpointing
964+ pytest . skip ( "Gradient checkpointing is not supported." )
962965
963966 # enable deterministic behavior for gradient checkpointing
964967 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
@@ -1015,8 +1018,9 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
10151018 def test_gradient_checkpointing_is_applied (
10161019 self , expected_set = None , attention_head_dim = None , num_attention_heads = None , block_out_channels = None
10171020 ):
1021+ # Skip test if model does not support gradient checkpointing
10181022 if not self .model_class ._supports_gradient_checkpointing :
1019- return # Skip test if model does not support gradient checkpointing
1023+ pytest . skip ( "Gradient checkpointing is not supported." )
10201024
10211025 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
10221026
@@ -1073,7 +1077,7 @@ def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
10731077 model = self .model_class (** init_dict ).to (torch_device )
10741078
10751079 if not issubclass (model .__class__ , PeftAdapterMixin ):
1076- return
1080+ pytest . skip ( f"PEFT is not supported for this model ( { model . __class__ . __name__ } )." )
10771081
10781082 torch .manual_seed (0 )
10791083 output_no_lora = model (** inputs_dict , return_dict = False )[0 ]
@@ -1128,7 +1132,7 @@ def test_lora_wrong_adapter_name_raises_error(self):
11281132 model = self .model_class (** init_dict ).to (torch_device )
11291133
11301134 if not issubclass (model .__class__ , PeftAdapterMixin ):
1131- return
1135+ pytest . skip ( f"PEFT is not supported for this model ( { model . __class__ . __name__ } )." )
11321136
11331137 denoiser_lora_config = LoraConfig (
11341138 r = 4 ,
@@ -1159,7 +1163,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_d
11591163 model = self .model_class (** init_dict ).to (torch_device )
11601164
11611165 if not issubclass (model .__class__ , PeftAdapterMixin ):
1162- return
1166+ pytest . skip ( f"PEFT is not supported for this model ( { model . __class__ . __name__ } )." )
11631167
11641168 denoiser_lora_config = LoraConfig (
11651169 r = rank ,
@@ -1196,7 +1200,7 @@ def test_lora_adapter_wrong_metadata_raises_error(self):
11961200 model = self .model_class (** init_dict ).to (torch_device )
11971201
11981202 if not issubclass (model .__class__ , PeftAdapterMixin ):
1199- return
1203+ pytest . skip ( f"PEFT is not supported for this model ( { model . __class__ . __name__ } )." )
12001204
12011205 denoiser_lora_config = LoraConfig (
12021206 r = 4 ,
@@ -1233,10 +1237,10 @@ def test_lora_adapter_wrong_metadata_raises_error(self):
12331237
12341238 @require_torch_accelerator
12351239 def test_cpu_offload (self ):
1240+ if self .model_class ._no_split_modules is None :
1241+ pytest .skip ("Test not supported for this model as `_no_split_modules` is not set." )
12361242 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
12371243 model = self .model_class (** config ).eval ()
1238- if model ._no_split_modules is None :
1239- return
12401244
12411245 model = model .to (torch_device )
12421246
@@ -1263,10 +1267,10 @@ def test_cpu_offload(self):
12631267
12641268 @require_torch_accelerator
12651269 def test_disk_offload_without_safetensors (self ):
1270+ if self .model_class ._no_split_modules is None :
1271+ pytest .skip ("Test not supported for this model as `_no_split_modules` is not set." )
12661272 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
12671273 model = self .model_class (** config ).eval ()
1268- if model ._no_split_modules is None :
1269- return
12701274
12711275 model = model .to (torch_device )
12721276
@@ -1296,10 +1300,10 @@ def test_disk_offload_without_safetensors(self):
12961300
12971301 @require_torch_accelerator
12981302 def test_disk_offload_with_safetensors (self ):
1303+ if self .model_class ._no_split_modules is None :
1304+ pytest .skip ("Test not supported for this model as `_no_split_modules` is not set." )
12991305 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
13001306 model = self .model_class (** config ).eval ()
1301- if model ._no_split_modules is None :
1302- return
13031307
13041308 model = model .to (torch_device )
13051309
@@ -1324,10 +1328,10 @@ def test_disk_offload_with_safetensors(self):
13241328
13251329 @require_torch_multi_accelerator
13261330 def test_model_parallelism (self ):
1331+ if self .model_class ._no_split_modules is None :
1332+ pytest .skip ("Test not supported for this model as `_no_split_modules` is not set." )
13271333 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
13281334 model = self .model_class (** config ).eval ()
1329- if model ._no_split_modules is None :
1330- return
13311335
13321336 model = model .to (torch_device )
13331337
@@ -1426,10 +1430,10 @@ def test_sharded_checkpoints_with_variant(self):
14261430
14271431 @require_torch_accelerator
14281432 def test_sharded_checkpoints_device_map (self ):
1433+ if self .model_class ._no_split_modules is None :
1434+ pytest .skip ("Test not supported for this model as `_no_split_modules` is not set." )
14291435 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
14301436 model = self .model_class (** config ).eval ()
1431- if model ._no_split_modules is None :
1432- return
14331437 model = model .to (torch_device )
14341438
14351439 torch .manual_seed (0 )
@@ -1497,7 +1501,7 @@ def test_variant_sharded_ckpt_right_format(self):
14971501 def test_layerwise_casting_training (self ):
14981502 def test_fn (storage_dtype , compute_dtype ):
14991503 if torch .device (torch_device ).type == "cpu" and compute_dtype == torch .bfloat16 :
1500- return
1504+ pytest . skip ( "Skipping test because CPU doesn't go well with bfloat16." )
15011505 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
15021506
15031507 model = self .model_class (** init_dict )
@@ -1617,6 +1621,9 @@ def get_memory_usage(storage_dtype, compute_dtype):
16171621 @parameterized .expand ([False , True ])
16181622 @require_torch_accelerator
16191623 def test_group_offloading (self , record_stream ):
1624+ if not self .model_class ._supports_group_offloading :
1625+ pytest .skip ("Model does not support group offloading." )
1626+
16201627 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
16211628 torch .manual_seed (0 )
16221629
@@ -1633,8 +1640,6 @@ def run_forward(model):
16331640 return model (** inputs_dict )[0 ]
16341641
16351642 model = self .model_class (** init_dict )
1636- if not getattr (model , "_supports_group_offloading" , True ):
1637- return
16381643
16391644 model .to (torch_device )
16401645 output_without_group_offloading = run_forward (model )
@@ -1670,13 +1675,13 @@ def run_forward(model):
16701675 @require_torch_accelerator
16711676 @torch .no_grad ()
16721677 def test_group_offloading_with_layerwise_casting (self , record_stream , offload_type ):
1678+ if not self .model_class ._supports_group_offloading :
1679+ pytest .skip ("Model does not support group offloading." )
1680+
16731681 torch .manual_seed (0 )
16741682 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
16751683 model = self .model_class (** init_dict )
16761684
1677- if not getattr (model , "_supports_group_offloading" , True ):
1678- return
1679-
16801685 model .to (torch_device )
16811686 model .eval ()
16821687 _ = model (** inputs_dict )[0 ]
@@ -1698,13 +1703,13 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
16981703 @require_torch_accelerator
16991704 @torch .no_grad ()
17001705 def test_group_offloading_with_disk (self , record_stream , offload_type ):
1706+ if not self .model_class ._supports_group_offloading :
1707+ pytest .skip ("Model does not support group offloading." )
1708+
17011709 torch .manual_seed (0 )
17021710 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
17031711 model = self .model_class (** init_dict )
17041712
1705- if not getattr (model , "_supports_group_offloading" , True ):
1706- return
1707-
17081713 torch .manual_seed (0 )
17091714 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
17101715 model = self .model_class (** init_dict )
0 commit comments