@@ -78,12 +78,17 @@ def set_tunableop_defaults():
7878 torch .cuda .tunable .set_max_tuning_iterations (100 )
7979 torch .cuda .tunable .set_rotating_buffer_size (- 1 )
8080
81- def tunableop_matmul (device , dtype ):
81+ def tunableop_matmul (device , dtype , offline = False ):
8282 # Helper function to test TunableOp in a subprocess
8383 # requires helper function since lambda function
8484 # not supported by multiprocessing module
8585 import os
8686 os .environ ["PYTORCH_TUNABLEOP_ENABLED" ] = "1"
87+
88+ if offline :
89+ torch .cuda .tunable .tuning_enable (False )
90+ torch .cuda .tunable .record_untuned_enable (True )
91+
8792 torch .cuda .tunable .set_max_tuning_duration (1 )
8893 A = torch .randn ((17 , 17 ), device = device , dtype = dtype )
8994 B = torch .randn ((17 , 17 ), device = device , dtype = dtype )
@@ -5661,6 +5666,108 @@ def test_tf32_offline_tunableop(self, device, dtype):
56615666 except (FileNotFoundError , PermissionError ):
56625667 pass
56635668
@onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.float16)
def test_blaslog_tunableop(self, device, dtype):
    # Test that PYTORCH_TUNABLEOP_BLAS_LOG=1 gives
    # an additional column of data with the BLAS
    # parameters in offline and online tuning.
    #
    # We record GEMMs and then check that the
    # BLAS_PARAMS appear in the
    # tunableop_untuned CSV file
    # and the
    # tunableop_results CSV file.
    #
    # NOTE: This is done in a subprocess
    # because in the main process
    # PYTORCH_TUNABLEOP_BLAS_LOG has already
    # been deactivated and its value is sticky.
    import os
    import multiprocessing as mp

    set_tunableop_defaults()
    ordinal = torch.cuda.current_device()

    result_filename = f"tunableop_results{ordinal}.csv"
    untuned_filename = f"tunableop_untuned{ordinal}.csv"

    # Test in try-finally block to avoid leaking state
    # if test is interrupted.
    try:
        # BUGFIX: set the variable via os.environ rather than os.putenv.
        # os.putenv does not update os.environ, so the
        # `del os.environ[...]` in the finally block always raised (and
        # silently swallowed) KeyError, and the C-level variable leaked
        # into the rest of the test process. Assigning through os.environ
        # both updates the environment inherited by spawned subprocesses
        # and makes the cleanup below actually effective.
        os.environ["PYTORCH_TUNABLEOP_BLAS_LOG"] = "1"

        # Offline tuning case in a subprocess.

        # force=True needed according to:
        # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
        # This is because a different test in this process could have
        # already set the start method
        mp.set_start_method("spawn", force=True)

        p = mp.Process(target=tunableop_matmul, args=(device, dtype, True))
        p.start()
        p.join()

        # Make sure the untuned file exists and that it is not zero-sized
        self.assertTrue(os.path.exists(untuned_filename))
        self.assertGreater(os.path.getsize(untuned_filename), 0)

        # Check that the BLAS PARAMS are in the CSV file
        import csv
        with open(untuned_filename) as file:
            reader = csv.reader(file)
            first_row = next(reader)
            # Check for extra column
            self.assertGreater(len(first_row), 3)
            # Check for YAML entry to the right of
            # BLAS PARAMS
            self.assertIn("{ function:", first_row[2])

        # Online tuning case in a subprocess.

        # force=True needed according to:
        # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
        # This is because a different test in this process could have
        # already set the start method
        mp.set_start_method("spawn", force=True)

        p = mp.Process(target=tunableop_matmul, args=(device, dtype, False))
        p.start()
        p.join()

        # Make sure the results file exists and that it is not zero-sized
        self.assertTrue(os.path.exists(result_filename))
        self.assertGreater(os.path.getsize(result_filename), 0)

        # Check that the BLAS PARAMS are in the CSV file
        with open(result_filename) as file:
            reader = csv.reader(file)
            for _ in range(5):  # Skip the first 5 lines for the validator
                next(reader, None)
            # Check for extra column
            first_row = next(reader)
            self.assertGreater(len(first_row), 5)
            # Check for YAML entry to the right of
            # BLAS PARAMS
            self.assertIn("{ function:", first_row[4])

    finally:
        # undo all the environment variables set
        try:
            del os.environ["PYTORCH_TUNABLEOP_BLAS_LOG"]
        except KeyError:
            pass

        # clean up, remove any files that were generated
        for filename in [untuned_filename, result_filename]:
            try:
                os.remove(filename)
            # NB: The file is locked on Windows
            except (FileNotFoundError, PermissionError):
                pass
56645771 @dtypes (torch .float , torch .complex64 )
56655772 def test_matmul_out_kernel_errors_with_autograd (self , device , dtype ):
56665773 a = torch .empty ((256 , 512 ), device = device , dtype = dtype , requires_grad = True ).unsqueeze (0 )
0 commit comments