
Commit 35f1e76

Reland of "[ROCm] change preferred blas lib defaults (pytorch#150249)" (pytorch#150707)
Revert "Revert "[ROCm] change preferred blas lib defaults (pytorch#150249)" (pytorch#150658)" This reverts commit 06c6a81.
Parent: a6321d6

7 files changed: +101 −11 lines

aten/src/ATen/BlasBackend.h

Lines changed: 3 additions & 1 deletion
@@ -7,10 +7,12 @@
 
 namespace at {
 
-enum class BlasBackend : int8_t { Cublas, Cublaslt, Ck };
+enum class BlasBackend : int8_t { Default, Cublas, Cublaslt, Ck };
 
 inline std::string BlasBackendToString(at::BlasBackend backend) {
   switch (backend) {
+    case BlasBackend::Default:
+      return "at::BlasBackend::Default";
     case BlasBackend::Cublas:
       return "at::BlasBackend::Cublas";
     case BlasBackend::Cublaslt:

aten/src/ATen/Context.cpp

Lines changed: 30 additions & 3 deletions
@@ -326,7 +326,34 @@ void Context::setLinalgPreferredBackend(at::LinalgBackend b) {
 }
 
 at::BlasBackend Context::blasPreferredBackend() {
+  // Rather than put logic for interpreting what Default means at every
+  // call site for blasPreferredBackend(), we set it to an actual value.
+  if (blas_preferred_backend == at::BlasBackend::Default) {
+    blas_preferred_backend = at::BlasBackend::Cublas;
 #ifdef USE_ROCM
+    // AMD Instinct targets prefer hipblaslt
+    static const bool hipblaslt_preferred = []() {
+      static const std::vector<std::string> archs = {
+          "gfx90a", "gfx942",
+#if ROCM_VERSION >= 60500
+          "gfx950"
+#endif
+      };
+      for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
+        if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
+          return false;
+        }
+      }
+      return true;
+    }();
+    if (hipblaslt_preferred) {
+      blas_preferred_backend = at::BlasBackend::Cublaslt;
+    }
+#endif
+  }
+
+#ifdef USE_ROCM
+  // hipblaslt support for all archs is not as complete as hipblas
   if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
     static const bool hipblaslt_unsupported = []() {
       static const std::vector<std::string> archs = {
@@ -338,7 +365,7 @@ at::BlasBackend Context::blasPreferredBackend() {
           "gfx950"
 #endif
       };
-      for (auto index: c10::irange(getNumGPUs())) {
+      for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
         if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
           TORCH_WARN_ONCE(
             "Attempting to use hipBLASLt on an unsupported architecture! "
@@ -365,7 +392,7 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
       "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
   TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
       "Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
-  if (b != at::BlasBackend::Cublas) {
+  if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) {
     TORCH_WARN_ONCE(
       "torch.backends.cuda.preferred_blas_library is an experimental feature. "
       "If you see any error or unexpected behavior when this flag is set "
@@ -391,7 +418,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
     static const std::vector<std::string> archs = {
         "gfx90a", "gfx942"
     };
-    for (auto index: c10::irange(getNumGPUs())) {
+    for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
       if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
         TORCH_WARN_ONCE(
             "Attempting to use CK on an unsupported architecture! Cannot set backend to CK");

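The Context.cpp change resolves Default lazily: the first call to blasPreferredBackend() rewrites Default to Cublas, and on ROCm upgrades it to Cublaslt only when every visible device is a supported Instinct architecture. A minimal Python sketch of that heuristic, for illustration only (resolve_default_blas_backend is a hypothetical helper; the real logic is the C++ above, and this sketch ignores the ROCM_VERSION >= 60500 gate on gfx950):

    import torch

    def resolve_default_blas_backend() -> str:
        """Sketch of how Context::blasPreferredBackend() resolves Default."""
        if torch.version.cuda:
            # CUDA builds always resolve Default to cuBLAS.
            return "cublas"
        # ROCm builds prefer hipBLASLt only if *every* visible GPU is a
        # supported Instinct architecture; otherwise fall back to hipBLAS.
        preferred = {"gfx90a", "gfx942", "gfx950"}
        for index in range(torch.cuda.device_count()):
            arch = torch.cuda.get_device_properties(index).gcnArchName.split(":", 1)[0]
            if arch not in preferred:
                return "cublas"
        return "cublaslt"
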
aten/src/ATen/Context.h

Lines changed: 5 additions & 7 deletions
@@ -446,17 +446,15 @@ class TORCH_API Context {
   bool allow_tf32_onednn = false;
   bool enabled_nnpack = true;
   at::LinalgBackend linalg_preferred_backend =
-      c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
+      (c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true ||
+       c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias
       ? at::LinalgBackend::Cusolver
       : at::LinalgBackend::Default;
   at::BlasBackend blas_preferred_backend =
-#ifdef USE_ROCM
-      (c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
-#else
-      (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
-#endif
+      (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
+       c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) // alias
       ? at::BlasBackend::Cublaslt
-      : at::BlasBackend::Cublas;
+      : at::BlasBackend::Default;
   at::ROCmFABackend rocm_fa_preferred_backend =
       c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true
       ? at::ROCmFABackend::Ck

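With these initializers, both environment variables act as aliases on either build, and the member starts out as Default unless one is set. A small usage sketch (the variable must be set before torch is imported, matching the subprocess check in the test below):

    import os

    # Must happen before importing torch, since Context reads the env var
    # when it is constructed. The HIPBLASLT spelling works the same way.
    os.environ["TORCH_BLAS_PREFER_CUBLASLT"] = "1"

    import torch
    print(torch.backends.cuda.preferred_blas_library())  # _BlasBackend.Cublaslt
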
test/test_cuda.py

Lines changed: 58 additions & 0 deletions
@@ -586,6 +586,64 @@ def test_serialization_array_with_storage(self):
         q_copy[1].fill_(10)
         self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
 
+    @setBlasBackendsToDefaultFinally
+    def test_preferred_blas_library_settings(self):
+        def _check_default():
+            default = torch.backends.cuda.preferred_blas_library()
+            if torch.version.cuda:
+                # CUDA logic is easy, it's always cublas
+                self.assertTrue(default == torch._C._BlasBackend.Cublas)
+            else:
+                # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else
+                gcn_arch = str(
+                    torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0]
+                )
+                if gcn_arch in ["gfx90a", "gfx942", "gfx950"]:
+                    self.assertTrue(default == torch._C._BlasBackend.Cublaslt)
+                else:
+                    self.assertTrue(default == torch._C._BlasBackend.Cublas)
+
+        _check_default()
+        # "Default" can be set but is immediately reset internally to the actual default value.
+        self.assertTrue(
+            torch.backends.cuda.preferred_blas_library("default")
+            != torch._C._BlasBackend.Default
+        )
+        _check_default()
+        self.assertTrue(
+            torch.backends.cuda.preferred_blas_library("cublas")
+            == torch._C._BlasBackend.Cublas
+        )
+        self.assertTrue(
+            torch.backends.cuda.preferred_blas_library("hipblas")
+            == torch._C._BlasBackend.Cublas
+        )
+        # check bad strings
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.",
+        ):
+            torch.backends.cuda.preferred_blas_library("unknown")
+        # check bad input type
+        with self.assertRaisesRegex(RuntimeError, "Unknown input value type."):
+            torch.backends.cuda.preferred_blas_library(1.0)
+        # check env var override
+        custom_envs = [
+            {"TORCH_BLAS_PREFER_CUBLASLT": "1"},
+            {"TORCH_BLAS_PREFER_HIPBLASLT": "1"},
+        ]
+        test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())"
+        for env_config in custom_envs:
+            env = os.environ.copy()
+            for key, value in env_config.items():
+                env[key] = value
+            r = (
+                subprocess.check_output([sys.executable, "-c", test_script], env=env)
+                .decode("ascii")
+                .strip()
+            )
+            self.assertEqual("_BlasBackend.Cublaslt", r)
+
     @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async")
     @setBlasBackendsToDefaultFinally
     def test_cublas_workspace_explicit_allocation(self):

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
@@ -1309,6 +1309,7 @@ def _get_blas_preferred_backend() -> torch._C._BlasBackend: ...
 def _set_blas_preferred_backend(arg: torch._C._BlasBackend): ...
 
 class _BlasBackend:
+    Default: _BlasBackend
     Cublas: _BlasBackend
     Cublaslt: _BlasBackend
     Ck: _BlasBackend

torch/backends/cuda/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -218,7 +218,9 @@ def preferred_linalg_library(
 
 
 _BlasBackends = {
+    "default": torch._C._BlasBackend.Default,
     "cublas": torch._C._BlasBackend.Cublas,
+    "hipblas": torch._C._BlasBackend.Cublas,  # alias
     "cublaslt": torch._C._BlasBackend.Cublaslt,
     "hipblaslt": torch._C._BlasBackend.Cublaslt,  # alias
     "ck": torch._C._BlasBackend.Ck,
@@ -241,6 +243,7 @@ def preferred_blas_library(
     * If `"cublas"` is set then cuBLAS will be used wherever possible.
     * If `"cublaslt"` is set then cuBLASLt will be used wherever possible.
     * If `"ck"` is set then CK will be used wherever possible.
+    * If `"default"` (the default) is set then heuristics will be used to pick between the other options.
     * When no input is given, this function returns the currently preferred library.
     * User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt
       globally.

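For reference, a short example of the documented API surface; "default" can be passed in, but reads always come back as a concrete backend because the getter resolves Default internally:

    import torch

    # Read the current preference without changing it.
    current = torch.backends.cuda.preferred_blas_library()

    # "hipblas"/"hipblaslt" map to the same values as "cublas"/"cublaslt".
    torch.backends.cuda.preferred_blas_library("cublaslt")

    # Re-enable the heuristic; the next read returns Cublas or Cublaslt,
    # never Default.
    torch.backends.cuda.preferred_blas_library("default")
    print(torch.backends.cuda.preferred_blas_library())
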
torch/csrc/Module.cpp

Lines changed: 1 addition & 0 deletions
@@ -2243,6 +2243,7 @@ Call this whenever a new thread is created in order to propagate values from
   });
 
   py::enum_<at::BlasBackend>(py_module, "_BlasBackend")
+      .value("Default", at::BlasBackend::Default)
      .value("Cublas", at::BlasBackend::Cublas)
      .value("Cublaslt", at::BlasBackend::Cublaslt)
      .value("Ck", at::BlasBackend::Ck);

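Once the binding is registered, the new member is visible from Python like any other pybind11 enum value:

    import torch

    print(torch._C._BlasBackend.Default)            # _BlasBackend.Default
    print(list(torch._C._BlasBackend.__members__))  # ['Default', 'Cublas', 'Cublaslt', 'Ck']
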