@@ -20,15 +20,16 @@ def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype):
2020
2121
2222@pytest.mark.skipif(
23-    not current_platform.is_cuda_alike(),
24-    reason="Currently only kernels on CUDA and ROCm",
23+    not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
24+    reason="Currently only kernels on CUDA, ROCm and XPU",
2525)
2626def test_rms_norm_registration():
2727    expected = {
2828        "native": True,
29-        "vllm_c": True,
29+        "vllm_c": current_platform.is_cuda_alike(),
3030        "aiter": current_platform.is_rocm(),
3131        "oink": False,
32+        "xpu_kernels": current_platform.is_xpu(),
3233    }
3334
3435 actual = {
@@ -43,13 +44,13 @@ def test_rms_norm_registration():
4344@pytest.mark.parametrize("hidden_size", [16, 4096, 8192])
4445@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
4546@pytest.mark.skipif(
46-    not current_platform.is_cuda_alike(),
47-    reason="Currently only kernels on CUDA and ROCm",
47+    not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
48+    reason="Currently only kernels on CUDA, ROCm and XPU",
4849)
4950class TestRMSNorm:
5051    @classmethod
5152    def setup_class(cls, **kwargs):
52-        torch.set_default_device("cuda")
53+        torch.set_default_device(current_platform.device_name)
5354
5455    def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
5556        x, weight = rms_norm_inputs(4, 8, dtype)
@@ -70,7 +71,7 @@ def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
7071        out4 = rms_norm_native(x, None, epsilon=epsilon)
7172        torch.testing.assert_close(out3, out4)
7273
73-    @pytest.mark.parametrize("provider", ["vllm_c", "aiter"])
74+    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"])
7475    def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
7576        impl = ir.ops.rms_norm.impls[provider]
7677        if not impl.supported:
@@ -115,7 +116,7 @@ def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
115116            atol=2e-4,
116117        )
117118
118-    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "native"])
119+    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
119120    def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
120121        if not ir.ops.rms_norm.impls[provider].supported:
121122            pytest.skip(f"{provider} impl not supported on this platform")
0 commit comments