1
1
# Owner(s): ["module: intel"]
2
2
3
+ import contextlib
3
4
import itertools
4
5
import math
5
6
import unittest
13
14
instantiate_device_type_tests ,
14
15
precisionOverride ,
15
16
)
16
- from torch .testing ._internal .common_dtype import floating_and_complex_types_and
17
- from torch .testing ._internal .common_mkldnn import bf32_on_and_off
17
+ from torch .testing ._internal .common_dtype import (
18
+ floating_and_complex_types_and ,
19
+ floating_types_and ,
20
+ )
21
+ from torch .testing ._internal .common_mkldnn import reduced_f32_on_and_off
18
22
from torch .testing ._internal .common_utils import (
19
23
IS_WINDOWS ,
20
24
parametrize ,
@@ -98,7 +102,7 @@ def preferred_linalg_library(self):
98
102
@precisionOverride ({torch .half : 0.05 , torch .bfloat16 : 0.05 })
99
103
@dtypes (* floating_and_complex_types_and (torch .bfloat16 , torch .half ))
100
104
@tf32_on_and_off (0.05 )
101
- @bf32_on_and_off (0.05 )
105
+ @reduced_f32_on_and_off (0.05 )
102
106
def addbmm (self , device , dtype ):
103
107
num_batches = 2
104
108
M , N , O = 16 , 17 , 18
@@ -392,6 +396,83 @@ def ck_blas_library(self):
392
396
pass
393
397
394
398
399
@precisionOverride(
    {
        torch.double: 1e-8,
        torch.float: 1e-4,
        torch.bfloat16: 5e-2,
        torch.half: 5e-2,
        torch.cfloat: 1e-4,
        torch.cdouble: 1e-8,
    }
)
@dtypes(*floating_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@reduced_f32_on_and_off(0.05)
def addmm_relu_tunableop_rocm(self, device, dtype):
    """Exercise the fused addmm + relu path with TunableOp enabled.

    Runs inside ``_tunableop_ctx`` so TunableOp is switched on for the
    duration of the test and all generated CSV files / env vars are cleaned
    up afterwards.  Tuning is kept cheap: a single tuning iteration and no
    rotating buffer.
    """
    with self._tunableop_ctx():
        # Keep the tuning phase minimal so the test stays fast.
        torch.xpu.tunable.set_max_tuning_iterations(1)
        torch.xpu.tunable.set_rotating_buffer_size(0)
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
417
+
418
+
419
def get_tunableop_untuned_filename():
    """Return the per-device "untuned" CSV filename TunableOp writes to.

    TunableOp inserts the current device ordinal immediately before the
    ``.csv`` suffix, so a base name of ``foo.csv`` on device 0 becomes
    ``foo0.csv``.  The base name comes from the
    ``PYTORCH_TUNABLEOP_UNTUNED_FILENAME`` environment variable.
    """
    import os

    ordinal = torch.xpu.current_device()
    untuned_filename_env = os.getenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME")
    if untuned_filename_env is None:
        # Guard against the env var being unset: this function is called
        # from the cleanup (finally) path of _tunableop_ctx, and letting
        # None.rpartition raise AttributeError there would mask the test's
        # real failure.  NOTE(review): falls back to TunableOp's documented
        # default untuned filename — confirm against the TunableOp docs.
        untuned_filename_env = "tunableop_untuned.csv"
    untuned_filename_base, _, _ = untuned_filename_env.rpartition(".")
    return f"{untuned_filename_base}{ordinal}.csv"
427
+
428
+
429
@contextlib.contextmanager
def __tunableop_ctx(self):
    """Enable TunableOp around the with-block, then tear everything down.

    On exit TunableOp is disabled again, every result/untuned CSV file that
    tuning produced is deleted (across all device ordinals), and the
    environment variables a TunableOp test may have set are removed.
    """
    import glob
    import os

    self._set_tunableop_defaults()
    torch.xpu.tunable.enable(True)

    try:
        yield
    finally:
        # Disable TunableOp before cleaning up its artifacts.
        torch.xpu.tunable.enable(False)

        # Both filenames end in "<ordinal>.csv"; strip the ".csv" via
        # rpartition and the ordinal via [:-1], then glob so files from
        # every device ordinal are removed.
        for base_name in (
            torch.xpu.tunable.get_filename(),
            get_tunableop_untuned_filename(),
        ):
            stem, _, _ = base_name.rpartition(".")
            for csv_file in glob.glob(f"{stem[:-1]}*.csv"):
                # NB: the file is locked on Windows, hence PermissionError.
                with contextlib.suppress(FileNotFoundError, PermissionError):
                    os.remove(csv_file)

        # Drop any environment variable a TunableOp test might have set;
        # pop with a default is a no-op when the variable is absent.
        for env_var in (
            "PYTORCH_TUNABLEOP_BLAS_LOG",
            "PYTORCH_TUNABLEOP_NUMERICAL_CHECK",
            "PYTORCH_TUNABLEOP_UNTUNED_FILENAME",
        ):
            os.environ.pop(env_var, None)
474
+
475
+
395
476
with XPUPatchForImport (False ):
396
477
from test_linalg import TestLinalg
397
478
@@ -410,6 +491,8 @@ def ck_blas_library(self):
410
491
TestLinalg .test_matmul_small_brute_force_2d_Nd = matmul_small_brute_force_2d_Nd
411
492
TestLinalg .test_matmul_small_brute_force_3d_Nd = matmul_small_brute_force_3d_Nd
412
493
TestLinalg .test_ck_blas_library = ck_blas_library
494
+ TestLinalg .test_addmm_relu_tunableop_rocm = addmm_relu_tunableop_rocm
495
+ TestLinalg ._tunableop_ctx = __tunableop_ctx
413
496
414
497
TestLinalg ._default_dtype_check_enabled = True
415
498
instantiate_device_type_tests (TestLinalg , globals (), only_for = ("xpu" ), allow_xpu = True )
0 commit comments