@@ -586,64 +586,6 @@ def test_serialization_array_with_storage(self):
         q_copy[1].fill_(10)
         self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

-    @setBlasBackendsToDefaultFinally
-    def test_preferred_blas_library_settings(self):
-        def _check_default():
-            default = torch.backends.cuda.preferred_blas_library()
-            if torch.version.cuda:
-                # CUDA logic is easy, it's always cublas
-                self.assertTrue(default == torch._C._BlasBackend.Cublas)
-            else:
-                # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else
-                gcn_arch = str(
-                    torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0]
-                )
-                if gcn_arch in ["gfx90a", "gfx942", "gfx950"]:
-                    self.assertTrue(default == torch._C._BlasBackend.Cublaslt)
-                else:
-                    self.assertTrue(default == torch._C._BlasBackend.Cublas)
-
-        _check_default()
-        # "Default" can be set but is immediately reset internally to the actual default value.
-        self.assertTrue(
-            torch.backends.cuda.preferred_blas_library("default")
-            != torch._C._BlasBackend.Default
-        )
-        _check_default()
-        self.assertTrue(
-            torch.backends.cuda.preferred_blas_library("cublas")
-            == torch._C._BlasBackend.Cublas
-        )
-        self.assertTrue(
-            torch.backends.cuda.preferred_blas_library("hipblas")
-            == torch._C._BlasBackend.Cublas
-        )
-        # check bad strings
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.",
-        ):
-            torch.backends.cuda.preferred_blas_library("unknown")
-        # check bad input type
-        with self.assertRaisesRegex(RuntimeError, "Unknown input value type."):
-            torch.backends.cuda.preferred_blas_library(1.0)
-        # check env var override
-        custom_envs = [
-            {"TORCH_BLAS_PREFER_CUBLASLT": "1"},
-            {"TORCH_BLAS_PREFER_HIPBLASLT": "1"},
-        ]
-        test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())"
-        for env_config in custom_envs:
-            env = os.environ.copy()
-            for key, value in env_config.items():
-                env[key] = value
-            r = (
-                subprocess.check_output([sys.executable, "-c", test_script], env=env)
-                .decode("ascii")
-                .strip()
-            )
-            self.assertEqual("_BlasBackend.Cublaslt", r)
-
     @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async")
     @setBlasBackendsToDefaultFinally
     def test_cublas_workspace_explicit_allocation(self):
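For reference, a minimal sketch (not part of this diff, assuming a CUDA- or ROCm-enabled PyTorch build) of the torch.backends.cuda.preferred_blas_library API that the removed test exercised:

    import torch

    # Query the currently preferred BLAS backend without changing it.
    current = torch.backends.cuda.preferred_blas_library()
    print(current)  # e.g. _BlasBackend.Cublas

    # Ask subsequent GEMMs to prefer the Lt interface; "hipblaslt" is the
    # equivalent spelling accepted on ROCm builds.
    torch.backends.cuda.preferred_blas_library("cublaslt")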