Skip to content

Commit 226850c

Browse files
Aidyn-Apytorchmergebot
authored andcommitted
[ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (pytorch#167734)
This PR add a sm_121a flag for row-wise scaled matmuls on DGX Spark. Pull Request resolved: pytorch#167734 Approved by: https://github.com/eqy, https://github.com/cyyever
1 parent f8a2ce3 commit 226850c

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

cmake/Codegen.cmake

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
118118
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
119119
endif()
120120
endif()
121+
if("${_arch}" STREQUAL "121a")
122+
if(_existing_arch_flags MATCHES ".*compute_120.*")
123+
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
124+
endif()
125+
endif()
121126
endforeach()
122127
list(JOIN _file_compile_flags " " _file_compile_flags)
123128

@@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)
126131

127132
_BUILD_FOR_ADDITIONAL_ARCHS(
128133
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
129-
"89;90a;100a;103a;120a")
134+
"89;90a;100a;103a;120a;121a")
130135
_BUILD_FOR_ADDITIONAL_ARCHS(
131136
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
132137
"90a")

0 commit comments

Comments
 (0)