Skip to content

Commit 5e00f35

Browse files
authored
[AMD] Remove bfloat16 fadd buffer atomic for gfx942 (#7011)
The only supported 16-bit buffer atomic fadd dtype on gfx942 is float16 (BUFFER_ATOMIC_PK_ADD_F16). There seems to be no corresponding instruction for bf16.

Without this PR, the following kernel

```
import torch
import triton
import triton.language as tl

@triton.jit
def atomic_add_bf16(X, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    t1 = tl.full((BLOCK_SIZE, ), 1, dtype=tl.bfloat16)
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    tl.atomic_add(X + offsets, t1)

X = torch.tensor([0, 1] * 256, device='cuda', dtype=torch.bfloat16)
Z = torch.tensor([1, 2] * 256, device='cuda', dtype=torch.bfloat16)
k = atomic_add_bf16[(1,)](X, 2 * 256)
assert (torch.equal(X, Z))
print("Success!")
```

will fail in instruction selection:

>~/triton]$ python ~/kernels/atomic_add_bf16.py
> LLVM ERROR: Cannot select: t56: v2bf16,ch = BUFFER_ATOMIC_FADD<(volatile dereferenceable load store (s32) on %ir.10, align 1, addrspace 8)> # D:1 t34, t37, t49, Constant:i32<0>, t20, Constant:i32<0>, TargetConstant:i32<0>, TargetConstant:i32<0>, TargetConstant:i1<0>, atomic_add_bf16.py:11:31
>..
1 parent c049167 commit 5e00f35

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
546546

547547
// -----
548548

549+
#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
550+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
551+
// CHECK-LABEL: atomic_add_bf16
552+
tt.func public @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
553+
%cst = arith.constant dense<true> : tensor<512xi1, #blocked>
554+
%cst_0 = arith.constant dense<1.000000e+00> : tensor<512xbf16, #blocked>
555+
%c512_i32 = arith.constant 512 : i32
556+
%0 = tt.get_program_id x : i32
557+
%1 = arith.muli %0, %c512_i32 : i32
558+
%2 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked>
559+
%3 = tt.addptr %arg0, %1 : !tt.ptr<bf16>, i32
560+
%4 = tt.splat %3 : !tt.ptr<bf16> -> tensor<512x!tt.ptr<bf16>, #blocked>
561+
%5 = tt.addptr %4, %2 : tensor<512x!tt.ptr<bf16>, #blocked>, tensor<512xi32, #blocked>
562+
// CHECK-NOT: amdgpu.buffer_atomic_rmw
563+
%6 = tt.atomic_rmw fadd, acq_rel, gpu, %5, %cst_0, %cst : (tensor<512x!tt.ptr<bf16>, #blocked>, tensor<512xbf16, #blocked>, tensor<512xi1, #blocked>) -> tensor<512xbf16, #blocked>
564+
tt.return
565+
}
566+
}
567+
568+
// -----
569+
549570
#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
550571
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
551572
// CHECK-LABEL: assume_positive_offset_buffer_atomic

third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,10 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
262262
mlir::MLIRContext *context,
263263
DenseMap<Value, SetVector<Operation *>> &assumptions,
264264
ModuleAxisInfoAnalysis &axisAnalysisPass,
265-
std::shared_ptr<DataFlowSolver> solver)
265+
std::shared_ptr<DataFlowSolver> solver, ISAFamily isaFamily)
266266
: mlir::OpRewritePattern<triton::AtomicRMWOp>(context),
267267
assumptions(assumptions), axisAnalysisPass(axisAnalysisPass),
268-
solver(std::move(solver)) {}
268+
solver(std::move(solver)), isaFamily(isaFamily) {}
269269

270270
mlir::LogicalResult
271271
matchAndRewrite(triton::AtomicRMWOp op,
@@ -323,6 +323,14 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
323323
}
324324
LDBG("RMW supported type");
325325

326+
// float16 is the only 16-bit dtype supported by buffer atomic fadd on
327+
// gfx942
328+
if (isaFamily == ISAFamily::CDNA3 && checkType.isBF16() &&
329+
atomicRmwOp == RMWOp::FADD) {
330+
return rewriter.notifyMatchFailure(op, "RMW FADD does not support bf16");
331+
}
332+
LDBG("RMW FADD supported 16-bit type");
333+
326334
auto vecSize = getVectorSize(ptr, axisAnalysisPass);
327335
// f16/bf16 dtypes could only be efficiently calculated using instructions
328336
// that pack 2 elements (e.g. @llvm.amdgcn.raw.buffer.atomic.fadd.v2f16)
@@ -387,6 +395,7 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
387395
DenseMap<Value, SetVector<Operation *>> assumptions;
388396
ModuleAxisInfoAnalysis &axisAnalysisPass;
389397
std::shared_ptr<DataFlowSolver> solver;
398+
ISAFamily isaFamily;
390399
};
391400

392401
// Workaround to allow static_assert(false) on older compilers as it was
@@ -541,9 +550,11 @@ class TritonAMDGPUConvertToBufferOpsPass
541550
// Gate buffer atomics behind CDNA3 for now
542551
// GFX942-specific assumptions regarding cache coherence are made when
543552
// lowering to LLVM
544-
if (ISAFamily::CDNA3 == triton::AMD::deduceISAFamily(archGenerationName))
553+
triton::AMD::ISAFamily isaFamily =
554+
triton::AMD::deduceISAFamily(archGenerationName);
555+
if (ISAFamily::CDNA3 == isaFamily)
545556
patterns.add<ConvertTritonAtomicRMWOpToBufferAtomicRMW>(
546-
context, assumptions, axisInfoAnalysis, solver);
557+
context, assumptions, axisInfoAnalysis, solver, isaFamily);
547558

548559
if (applyPatternsGreedily(mod, std::move(patterns)).failed())
549560
signalPassFailure();

0 commit comments

Comments
 (0)