diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 7a8d02be530e..eafaff8d0dcc 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -471,8 +471,11 @@ at::BlasBackend Context::blasPreferredBackend() { static const bool hipblaslt_preferred = []() { static const std::vector archs = { "gfx90a", "gfx942", -#if ROCM_VERSION >= 60400 - "gfx1200", "gfx1201", +#if ROCM_VERSION >= 60300 + "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201", +#endif +#if ROCM_VERSION >= 60402 + "gfx1150", "gfx1151", #endif #if ROCM_VERSION >= 60500 "gfx950" @@ -502,7 +505,10 @@ at::BlasBackend Context::blasPreferredBackend() { static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 - "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908", + "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201", "gfx908", +#endif +#if ROCM_VERSION >= 60402 + "gfx1150", "gfx1151", #endif #if ROCM_VERSION >= 60500 "gfx950" diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 23447c7e09b3..49356f8c79bc 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -283,7 +283,10 @@ static bool isSupportedHipLtROCmArch(int index) { static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 - "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908", + "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201", "gfx908", +#endif +#if ROCM_VERSION >= 60402 + "gfx1150", "gfx1151", #endif #if ROCM_VERSION >= 60500 "gfx950"