 )
 from torchao.quantization.utils import compute_error
 from torchao.sparsity.sparse_api import apply_fake_sparsity
+from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
 )
@@ -38,6 +39,7 @@ class TestInt4MarlinSparseTensor(TestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("config", [BF16_ACT_CONFIG])
     @parametrize(
         "sizes",
@@ -65,6 +67,7 @@ def test_linear(self, config, sizes):
         quantized_and_compiled = compiled_linear(input)
         self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

+    @skip_if_rocm("ROCm enablement in progress")
     @unittest.skip("Fix later")
     @parametrize("config", [BF16_ACT_CONFIG])
     def test_to_device(self, config):
@@ -81,6 +84,7 @@ def test_to_device(self, config):
             quantize_(linear, config)
             linear.to(device)

+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("config", [BF16_ACT_CONFIG])
     def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
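For context, here is a minimal sketch of what a skip-if-ROCm decorator like the `skip_if_rocm` imported above might look like. This is a hypothetical stand-in for illustration, not torchao's actual implementation; it relies only on `torch.version.hip`, which is a version string on ROCm builds of PyTorch and None otherwise.

import unittest

import torch


def skip_if_rocm(reason):
    # Hypothetical stand-in for torchao.testing.utils.skip_if_rocm:
    # torch.version.hip is non-None only on ROCm builds, so the wrapped
    # test is skipped with the given reason only when running under ROCm.
    return unittest.skipIf(torch.version.hip is not None, reason)

Applied as in the diff, `@skip_if_rocm("ROCm enablement in progress")` leaves the tests untouched on CUDA builds while marking them as skipped (with that reason) on ROCm until enablement lands.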