@@ -262,10 +262,10 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
262
262
mlir::MLIRContext *context,
263
263
DenseMap<Value, SetVector<Operation *>> &assumptions,
264
264
ModuleAxisInfoAnalysis &axisAnalysisPass,
265
- std::shared_ptr<DataFlowSolver> solver)
265
+ std::shared_ptr<DataFlowSolver> solver, ISAFamily isaFamily )
266
266
: mlir::OpRewritePattern<triton::AtomicRMWOp>(context),
267
267
assumptions (assumptions), axisAnalysisPass(axisAnalysisPass),
268
- solver (std::move(solver)) {}
268
+ solver (std::move(solver)), isaFamily(isaFamily) {}
269
269
270
270
mlir::LogicalResult
271
271
matchAndRewrite (triton::AtomicRMWOp op,
@@ -323,6 +323,14 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
323
323
}
324
324
LDBG (" RMW supported type" );
325
325
326
+ // float16 is the only 16-bit dtype supported by buffer atomic fadd on
327
+ // gfx942
328
+ if (isaFamily == ISAFamily::CDNA3 && checkType.isBF16 () &&
329
+ atomicRmwOp == RMWOp::FADD) {
330
+ return rewriter.notifyMatchFailure (op, " RMW FADD does not support bf16" );
331
+ }
332
+ LDBG (" RMW FADD supported 16-bit type" );
333
+
326
334
auto vecSize = getVectorSize (ptr, axisAnalysisPass);
327
335
// f16/bf16 dtypes could only be efficiently calculated using instructions
328
336
// that pack 2 elements (e.g. @llvm.amdgcn.raw.buffer.atomic.fadd.v2f16)
@@ -387,6 +395,7 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
387
395
DenseMap<Value, SetVector<Operation *>> assumptions;
388
396
ModuleAxisInfoAnalysis &axisAnalysisPass;
389
397
std::shared_ptr<DataFlowSolver> solver;
398
+ ISAFamily isaFamily;
390
399
};
391
400
392
401
// Workaround to allow static_assert(false) on older compilers as it was
@@ -541,9 +550,11 @@ class TritonAMDGPUConvertToBufferOpsPass
541
550
// Gate buffer atomics behind CDNA3 for now
542
551
// GFX942-specific assumptions regarding cache coherence are made when
543
552
// lowering to LLVM
544
- if (ISAFamily::CDNA3 == triton::AMD::deduceISAFamily (archGenerationName))
553
+ triton::AMD::ISAFamily isaFamily =
554
+ triton::AMD::deduceISAFamily (archGenerationName);
555
+ if (ISAFamily::CDNA3 == isaFamily)
545
556
patterns.add <ConvertTritonAtomicRMWOpToBufferAtomicRMW>(
546
- context, assumptions, axisInfoAnalysis, solver);
557
+ context, assumptions, axisInfoAnalysis, solver, isaFamily );
547
558
548
559
if (applyPatternsGreedily (mod, std::move (patterns)).failed ())
549
560
signalPassFailure ();
0 commit comments