@@ -236,7 +236,7 @@ def test_quantization(self):
236236                ("uint7wo" , np .array ([0.4648 , 0.5195 , 0.5547 , 0.4219 , 0.4414 , 0.6445 , 0.4316 , 0.4531 , 0.5625 ])),
237237            ]
238238
239-             if  TorchAoConfig ._is_cuda_capability_atleast_8_9 ():
239+             if  TorchAoConfig ._is_xpu_or_cuda_capability_atleast_8_9 ():
240240                QUANTIZATION_TYPES_TO_TEST .extend ([
241241                    ("float8wo_e5m2" , np .array ([0.4590 , 0.5273 , 0.5547 , 0.4219 , 0.4375 , 0.6406 , 0.4316 , 0.4512 , 0.5625 ])),
242242                    ("float8wo_e4m3" , np .array ([0.4648 , 0.5234 , 0.5547 , 0.4219 , 0.4414 , 0.6406 , 0.4316 , 0.4531 , 0.5625 ])),
@@ -753,7 +753,7 @@ def test_quantization(self):
753753            ("int8dq" , np .array ([0.0546 , 0.0761 , 0.1386 , 0.0488 , 0.0644 , 0.1425 , 0.0605 , 0.0742 , 0.1406 , 0.0625 , 0.0722 , 0.1523 , 0.0625 , 0.0742 , 0.1503 , 0.0605 , 0.3886 , 0.7968 , 0.5507 , 0.4492 , 0.7890 , 0.5351 , 0.4316 , 0.8007 , 0.5390 , 0.4179 , 0.8281 , 0.5820 , 0.4531 , 0.7812 , 0.5703 , 0.4921 ])),
754754        ]
755755
756-         if  TorchAoConfig ._is_cuda_capability_atleast_8_9 ():
756+         if  TorchAoConfig ._is_xpu_or_cuda_capability_atleast_8_9 ():
757757            QUANTIZATION_TYPES_TO_TEST .extend ([
758758                ("float8wo_e4m3" , np .array ([0.0546 , 0.0722 , 0.1328 , 0.0468 , 0.0585 , 0.1367 , 0.0605 , 0.0703 , 0.1328 , 0.0625 , 0.0703 , 0.1445 , 0.0585 , 0.0703 , 0.1406 , 0.0605 , 0.3496 , 0.7109 , 0.4843 , 0.4042 , 0.7226 , 0.5000 , 0.4160 , 0.7031 , 0.4824 , 0.3886 , 0.6757 , 0.4667 , 0.3710 , 0.6679 , 0.4902 , 0.4238 ])),
759759                ("fp5_e3m1" , np .array ([0.0527 , 0.0762 , 0.1309 , 0.0449 , 0.0645 , 0.1328 , 0.0566 , 0.0723 , 0.125 , 0.0566 , 0.0703 , 0.1328 , 0.0566 , 0.0742 , 0.1348 , 0.0566 , 0.3633 , 0.7617 , 0.5273 , 0.4277 , 0.7891 , 0.5469 , 0.4375 , 0.8008 , 0.5586 , 0.4336 , 0.7383 , 0.5156 , 0.3906 , 0.6992 , 0.5156 , 0.4375 ])),
0 commit comments