Skip to content

Commit 2f023bf

Browse files
Aidyn-Apytorchmergebot
authored andcommitted
[ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (pytorch#167734)
This PR add a sm_121a flag for row-wise scaled matmuls on DGX Spark. Pull Request resolved: pytorch#167734 Approved by: https://github.com/eqy, https://github.com/cyyever
1 parent 9760a63 commit 2f023bf

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

cmake/Codegen.cmake

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ if(INTERN_BUILD_ATEN_OPS)
118118
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
119119
endif()
120120
endif()
121+
# We will need to gate against CUDA version, sm_121a was introduced in CUDA 12.9
122+
if("${_arch}" STREQUAL "121a" AND CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
123+
if(_existing_arch_flags MATCHES ".*compute_120.*")
124+
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
125+
endif()
126+
endif()
121127
endforeach()
122128
list(JOIN _file_compile_flags " " _file_compile_flags)
123129

@@ -126,7 +132,7 @@ if(INTERN_BUILD_ATEN_OPS)
126132

127133
_BUILD_FOR_ADDITIONAL_ARCHS(
128134
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
129-
"89;90a;100a;103a;120a")
135+
"89;90a;100a;103a;120a;121a")
130136
_BUILD_FOR_ADDITIONAL_ARCHS(
131137
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
132138
"90a")

0 commit comments

Comments
 (0)