@@ -586,6 +586,64 @@ def test_serialization_array_with_storage(self):
586586 q_copy [1 ].fill_ (10 )
587587 self .assertEqual (q_copy [3 ], torch .cuda .IntStorage (10 ).fill_ (10 ))
588588
589+ @setBlasBackendsToDefaultFinally
590+ def test_preferred_blas_library_settings (self ):
591+ def _check_default ():
592+ default = torch .backends .cuda .preferred_blas_library ()
593+ if torch .version .cuda :
594+ # CUDA logic is easy, it's always cublas
595+ self .assertTrue (default == torch ._C ._BlasBackend .Cublas )
596+ else :
597+ # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else
598+ gcn_arch = str (
599+ torch .cuda .get_device_properties (0 ).gcnArchName .split (":" , 1 )[0 ]
600+ )
601+ if gcn_arch in ["gfx90a" , "gfx942" , "gfx950" ]:
602+ self .assertTrue (default == torch ._C ._BlasBackend .Cublaslt )
603+ else :
604+ self .assertTrue (default == torch ._C ._BlasBackend .Cublas )
605+
606+ _check_default ()
607+ # "Default" can be set but is immediately reset internally to the actual default value.
608+ self .assertTrue (
609+ torch .backends .cuda .preferred_blas_library ("default" )
610+ != torch ._C ._BlasBackend .Default
611+ )
612+ _check_default ()
613+ self .assertTrue (
614+ torch .backends .cuda .preferred_blas_library ("cublas" )
615+ == torch ._C ._BlasBackend .Cublas
616+ )
617+ self .assertTrue (
618+ torch .backends .cuda .preferred_blas_library ("hipblas" )
619+ == torch ._C ._BlasBackend .Cublas
620+ )
621+ # check bad strings
622+ with self .assertRaisesRegex (
623+ RuntimeError ,
624+ "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck." ,
625+ ):
626+ torch .backends .cuda .preferred_blas_library ("unknown" )
627+ # check bad input type
628+ with self .assertRaisesRegex (RuntimeError , "Unknown input value type." ):
629+ torch .backends .cuda .preferred_blas_library (1.0 )
630+ # check env var override
631+ custom_envs = [
632+ {"TORCH_BLAS_PREFER_CUBLASLT" : "1" },
633+ {"TORCH_BLAS_PREFER_HIPBLASLT" : "1" },
634+ ]
635+ test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())"
636+ for env_config in custom_envs :
637+ env = os .environ .copy ()
638+ for key , value in env_config .items ():
639+ env [key ] = value
640+ r = (
641+ subprocess .check_output ([sys .executable , "-c" , test_script ], env = env )
642+ .decode ("ascii" )
643+ .strip ()
644+ )
645+ self .assertEqual ("_BlasBackend.Cublaslt" , r )
646+
589647 @unittest .skipIf (TEST_CUDAMALLOCASYNC , "temporarily disabled for async" )
590648 @setBlasBackendsToDefaultFinally
591649 def test_cublas_workspace_explicit_allocation (self ):
0 commit comments