diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index eba8cc23b7e1..3a401c46fb5e 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -2107,7 +2107,7 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1): @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_compiled_model_conv2d(self, rank0, rank1): if "unet" not in self.model_class.__name__.lower(): - return + pytest.skip("Test only applies to UNet.") # It's important to add this context to raise an error on recompilation target_modules = ["conv", "conv1", "conv2"] @@ -2117,7 +2117,7 @@ def test_hotswapping_compiled_model_conv2d(self, rank0, rank1): @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1): if "unet" not in self.model_class.__name__.lower(): - return + pytest.skip("Test only applies to UNet.") # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "conv"]