@@ -12,6 +12,7 @@
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    skipIfRocmArch,
     TEST_WITH_ROCM,
 )
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -32,6 +33,17 @@
 FP16_MAX_POS: float = torch.finfo(torch.float16).max
 EPS: float = 1e-12
 
+# fp8 data types for inductor-based fp8 tests. These can differ from
+# the ones used in eager mode.
+f8_type_pair = (torch.float8_e4m3fn, torch.float8_e5m2)
+if torch.version.hip:
+    arch = torch.cuda.get_device_properties(0).gcnArchName
+    if "gfx94" in arch:
+        # for gfx942, use the fnuz data types.
+        f8_type_pair = (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+    elif "gfx120" in arch:
+        # for gfx1200 and gfx1201, e4m3 is not supported on Triton.
+        f8_type_pair = (torch.float8_e5m2,)
 
 def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
     # The default behavior in PyTorch for casting to `float8_e4m3fn`
@@ -180,10 +192,11 @@ def fp8_matmul_unwrapped(x):
         x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
         y_fp8 = compiled_fp8_matmul(x)  # noqa: F841
 
+    @skipIfRocmArch(("gfx1200", "gfx1201"))
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
     @parametrize("shape", ("15,3,13", "4,2048,4096"))
-    @parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)])
+    @parametrize("dst_types", [f8_type_pair])
     def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
         dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda")
         e4m3, e5m2 = dst_types
@@ -227,7 +240,7 @@ def fp8_cast(x, dtype):
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
-    @parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("dst_dtype", f8_type_pair)
     @parametrize("shape", ("16,16,16", "4,2048,4096"))
     def test_to_fp8_saturated(
         self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
@@ -249,7 +262,7 @@ def fp8_saturated(x, dtype):
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
         float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
@@ -274,7 +287,7 @@ def amax_fp8(x: Tensor, scale: Tensor):
         torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
         float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
@@ -305,7 +318,7 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("amax_keep_dim", (True, False))
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_layernorm_fp8_quant(
@@ -347,7 +360,7 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
         )
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("4,2048,4096",))
     @parametrize("keepdim", (False, True))
     def test_layernorm_fp8_quant_benchmark(
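
For reference, a minimal standalone sketch (not part of the PR) that mirrors the module-level dtype selection introduced in the second hunk above; it can be run to check which fp8 pair the tests would be parametrized over on a given machine. It assumes a PyTorch build that exposes the fp8 dtypes and, on ROCm, at least one visible GPU; the final print is only for illustration.

import torch

# Default pair used on CUDA; ROCm architectures override it below,
# mirroring the selection logic in the test module.
f8_type_pair = (torch.float8_e4m3fn, torch.float8_e5m2)
if torch.version.hip and torch.cuda.is_available():
    arch = torch.cuda.get_device_properties(0).gcnArchName
    if "gfx94" in arch:
        # gfx942 uses the fnuz variants of the fp8 dtypes.
        f8_type_pair = (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
    elif "gfx120" in arch:
        # gfx1200/gfx1201: e4m3 is not supported by Triton, so only e5m2 is kept.
        f8_type_pair = (torch.float8_e5m2,)

print("fp8 dtypes under test:", f8_type_pair)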