Skip to content

Commit e0160f1

Browse files
alugoreyjithunnair-amd
authored andcommitted
Extend CK gemm/sdpa support to gfx950 (#45)
Update CK for gfx950 (#49) (cherry picked from commit 8ccfc47) (cherry picked from commit b5d5987)
1 parent e62e394 commit e0160f1

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

aten/src/ATen/Context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
419419
if(b == at::ROCmFABackend::Ck) {
420420
static const bool ck_unsupported = []() {
421421
static const std::vector<std::string> archs = {
422-
"gfx90a", "gfx942"
422+
"gfx90a", "gfx942", "gfx950"
423423
};
424424
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
425425
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {

aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,4 +453,5 @@ struct fmha_bwd_traits
453453
bool is_deterministic;
454454
// TODO: padding check is inside this api
455455
};
456+
template <int Version = 2>
456457
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);

0 commit comments

Comments
 (0)