Skip to content

Commit b66fbd6

Browse files
authored
Extend lowering for atomic_rmw to bfloat16 (#4747)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 8e3d971 commit b66fbd6

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1507,7 +1507,9 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
15071507
pytest.skip("Broken on rocm")
15081508
if is_xpu():
15091509
if (kind, dtype_str) in [("add", "bfloat16")]:
1510-
pytest.skip("FIXME: issue #4375")
1510+
if descriptor == "host":
1511+
pytest.skip("FIXME: issue #4289")
1512+
pytest.skip("FIXME: issue #3914")
15111513

15121514
@triton.jit(debug=True)
15131515
def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3307,25 +3307,25 @@ struct AtomicRMWOpConversion
33073307
valueElemNBits == 64) &&
33083308
"Unexpected width");
33093309

3310-
Value zero;
3311-
llvm::TypeSwitch<mlir::Type>(valueElemTy)
3312-
.Case<mlir::IntegerType>(
3313-
[&](auto ty) { zero = b.int_val(valueElemNBits, 0); })
3314-
.Case<mlir::Float16Type>([&](auto ty) { zero = b.f16_val(0); })
3315-
.Case<mlir::Float32Type>([&](auto ty) { zero = b.f32_val(0); })
3316-
.Case<mlir::Float64Type>([&](auto ty) { zero = b.f64_val(0); });
3310+
Value zero =
3311+
TypeSwitch<mlir::Type, Value>(valueElemTy)
3312+
.Case<mlir::IntegerType>(
3313+
[&](auto ty) { return b.int_val(valueElemNBits, 0); })
3314+
.Case<mlir::Float16Type>([&](auto) { return b.f16_val(0); })
3315+
.Case<mlir::BFloat16Type>([&](auto) { return b.bf16_val(0); })
3316+
.Case<mlir::Float32Type>([&](auto) { return b.f32_val(0); })
3317+
.Case<mlir::Float64Type>([&](auto) { return b.f64_val(0); });
33173318

33183319
// TODO: check device capabilities to avoid unnecessary emulation or
33193320
// emit unsupported feature error.
33203321
Value ret;
33213322
bool support16BitAtomics = moduleOp->hasAttr(
33223323
TritonIntelGPUDialect::getSupport16BitAtomicsAttrName());
33233324
if (valueElemNBits == 16 && !support16BitAtomics) {
3324-
op.emitWarning(
3325-
"'tt.atomic_rmw' op fp16 datatype is not supported in the target "
3326-
"HW, software emulation is an experimental feature (use at own "
3327-
"risk)");
3328-
Block *endBlock = emulateFp16AtomicRmw(
3325+
op.emitWarning("'tt.atomic_rmw' op fp16/bf16 datatype is not supported "
3326+
"in the target HW, software emulation is an "
3327+
"experimental feature (use at own risk)");
3328+
Block *endBlock = emulate16BitsAtomicRmw(
33293329
rewriter, loc, atomicRmwAttr, valueElemTy, rmwPtr, rmwVal,
33303330
maybeAnd(rewriter, loc, b.true_val(), rmwMask), {zero});
33313331
ret = endBlock->getArgument(0);
@@ -3391,10 +3391,10 @@ struct AtomicRMWOpConversion
33913391

33923392
// Emulate 16-bit atomicrmw through a loop with 32-bit cmpxchg.
33933393
// TODO: optimize for the case when rmwMask is a true constant?
3394-
Block *emulateFp16AtomicRmw(ConversionPatternRewriter &rewriter, Location loc,
3395-
mlir::triton::RMWOp atomicOp, Type valueElemTy,
3396-
Value rmwPtr, Value rmwVal, Value rmwMask,
3397-
ArrayRef<Value> ops) const {
3394+
Block *emulate16BitsAtomicRmw(ConversionPatternRewriter &rewriter,
3395+
Location loc, mlir::triton::RMWOp atomicOp,
3396+
Type valueElemTy, Value rmwPtr, Value rmwVal,
3397+
Value rmwMask, ArrayRef<Value> ops) const {
33983398
auto b = TritonLLVMOpBuilder(loc, rewriter);
33993399
Block *insertionBlock = rewriter.getInsertionBlock();
34003400
Block *headerBlock =

0 commit comments

Comments
 (0)