Commit 888bb4c

naromero77amd authored and amathewc committed
[ROCm][TunableOp] Unit test for TunableOp BLAS logging. (pytorch#148982)
Add a unit test for the new TunableOp BLAS logging feature. Requires pytorch#148979 to be merged in first.

Pull Request resolved: pytorch#148982
Approved by: https://github.com/jeffdaily
1 parent 7460c75 commit 888bb4c
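The feature under test is driven by environment variables. As a quick orientation, here is a minimal sketch (not part of this commit; it uses only the variables and defaults the test below exercises): setting PYTORCH_TUNABLEOP_BLAS_LOG=1 before the first tunable GEMM makes TunableOp append the BLAS parameters as an extra CSV column.

    import os

    # Both variables must be set before the first GEMM runs; as the
    # test notes, the BLAS_LOG value is sticky for the process.
    os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
    os.environ["PYTORCH_TUNABLEOP_BLAS_LOG"] = "1"

    import torch

    A = torch.randn((17, 17), device="cuda", dtype=torch.float16)
    B = torch.randn((17, 17), device="cuda", dtype=torch.float16)
    C = torch.matmul(A, B)

    # By default, TunableOp writes tunableop_results{ordinal}.csv at
    # process exit; with BLAS logging enabled, each row carries an
    # extra column holding a YAML fragment ("{ function: ... }").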

File tree: 1 file changed (+108 −1)

test/test_linalg.py

Lines changed: 108 additions & 1 deletion
@@ -78,12 +78,17 @@ def set_tunableop_defaults():
     torch.cuda.tunable.set_max_tuning_iterations(100)
     torch.cuda.tunable.set_rotating_buffer_size(-1)
 
-def tunableop_matmul(device, dtype):
+def tunableop_matmul(device, dtype, offline=False):
     # Helper function to test TunableOp in a subprocess
     # requires helper function since lambda function
     # not supported by multiprocessing module
     import os
     os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
+
+    if offline:
+        torch.cuda.tunable.tuning_enable(False)
+        torch.cuda.tunable.record_untuned_enable(True)
+
     torch.cuda.tunable.set_max_tuning_duration(1)
     A = torch.randn((17, 17), device=device, dtype=dtype)
     B = torch.randn((17, 17), device=device, dtype=dtype)
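The offline=False default preserves the helper's original online-tuning behavior; passing offline=True switches it to recording mode. Outside the test harness, that mode looks roughly like this (a sketch using only the calls shown in the hunk above; file names assume TunableOp's defaults):

    import os

    os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"

    import torch

    # Offline mode: record the GEMMs that are encountered instead of
    # tuning them as they run.
    torch.cuda.tunable.tuning_enable(False)
    torch.cuda.tunable.record_untuned_enable(True)

    A = torch.randn((17, 17), device="cuda", dtype=torch.float16)
    B = torch.randn((17, 17), device="cuda", dtype=torch.float16)
    C = torch.matmul(A, B)

    # The untuned GEMMs land in tunableop_untuned{ordinal}.csv, which
    # offline tuning can consume later.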
@@ -5661,6 +5666,108 @@ def test_tf32_offline_tunableop(self, device, dtype):
                 except (FileNotFoundError, PermissionError):
                     pass
 
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float16)
+    def test_blaslog_tunableop(self, device, dtype):
+        # Test that PYTORCH_TUNABLEOP_BLAS_LOG=1 gives
+        # an additional column of data with the BLAS
+        # parameters in offline and online tuning.
+        #
+        # We record GEMMs and then check that the
+        # BLAS_PARAMS appear in the
+        # tunableop_untuned CSV file
+        # and the
+        # tunableop_results CSV file
+        #
+        # NOTE: This is done in a subprocess
+        # because in the main process
+        # PYTORCH_TUNABLEOP_BLAS_LOG has already
+        # been deactivated and its value is sticky
+        import os
+        import multiprocessing as mp
+
+        set_tunableop_defaults()
+        ordinal = torch.cuda.current_device()
+
+        result_filename = f"tunableop_results{ordinal}.csv"
+        untuned_filename = f"tunableop_untuned{ordinal}.csv"
+
+        # Test in a try-finally block to avoid leaking state
+        # if the test is interrupted.
+        try:
+            os.putenv("PYTORCH_TUNABLEOP_BLAS_LOG", "1")
+
+            # Offline tuning case in a subprocess
+
+            # force=True needed according to:
+            # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
+            # This is because a different test in this process could have
+            # already set the start method
+            mp.set_start_method("spawn", force=True)
+
+            p = mp.Process(target=tunableop_matmul, args=(device, dtype, True))
+            p.start()
+            p.join()
+
+            # Make sure the untuned file exists and that it is not empty
+            self.assertTrue(os.path.exists(untuned_filename))
+            self.assertTrue(os.path.getsize(untuned_filename) > 0)
+
+            # Check that the BLAS PARAMS are in the CSV file
+            import csv
+            with open(untuned_filename) as file:
+                reader = csv.reader(file)
+                first_row = next(reader)
+                # Check for the extra column
+                self.assertGreater(len(first_row), 3)
+                # Check for the YAML entry holding the
+                # BLAS PARAMS
+                self.assertTrue("{ function:" in first_row[2])
+
+            # Online tuning case in a subprocess
+
+            # force=True needed according to:
+            # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
+            # This is because a different test in this process could have
+            # already set the start method
+            mp.set_start_method("spawn", force=True)
+
+            p = mp.Process(target=tunableop_matmul, args=(device, dtype, False))
+            p.start()
+            p.join()
+
+            # Make sure the results file exists and that it is not empty
+            self.assertTrue(os.path.exists(result_filename))
+            self.assertGreater(os.path.getsize(result_filename), 0)
+
+            # Check that the BLAS PARAMS are in the CSV file
+            with open(result_filename) as file:
+                reader = csv.reader(file)
+                for _ in range(5):  # Skip the first 5 lines for the validator
+                    next(reader, None)
+                # Check for the extra column
+                first_row = next(reader)
+                self.assertGreater(len(first_row), 5)
+                # Check for the YAML entry holding the
+                # BLAS PARAMS
+                self.assertTrue("{ function:" in first_row[4])
+
+        finally:
+            # undo all the environment variables that were set
+            try:
+                del os.environ["PYTORCH_TUNABLEOP_BLAS_LOG"]
+            except KeyError:
+                pass
+
+            # clean up: remove any files that were generated
+            for filename in [untuned_filename, result_filename]:
+                try:
+                    os.remove(filename)
+                # NB: The file is locked on Windows
+                except (FileNotFoundError, PermissionError):
+                    pass
+
     @dtypes(torch.float, torch.complex64)
     def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
         a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
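Two details are worth noting about the test above. First, it sets the variable with os.putenv, which does not update the os.environ mapping, so the cleanup's del os.environ[...] can legitimately raise KeyError; that is why the finally block tolerates it. Second, the assertions only probe for the "{ function:" prefix. A hypothetical helper (not part of this commit) that extracts the full BLAS-parameter entries, assuming the untuned-file layout the test asserts (YAML fragment in the third column), might look like:

    import csv

    def read_blas_params(untuned_filename):
        # Collect the BLAS-parameter YAML fragments from a
        # tunableop_untuned{ordinal}.csv file.
        params = []
        with open(untuned_filename) as f:
            for row in csv.reader(f):
                if len(row) > 2 and "{ function:" in row[2]:
                    params.append(row[2])
        return params

    for entry in read_blas_params("tunableop_untuned0.csv"):
        print(entry)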
