@@ -3307,25 +3307,25 @@ struct AtomicRMWOpConversion
3307
3307
valueElemNBits == 64 ) &&
3308
3308
" Unexpected width" );
3309
3309
3310
- Value zero;
3311
- llvm::TypeSwitch<mlir::Type>(valueElemTy)
3312
- .Case <mlir::IntegerType>(
3313
- [&](auto ty) { zero = b.int_val (valueElemNBits, 0 ); })
3314
- .Case <mlir::Float16Type>([&](auto ty) { zero = b.f16_val (0 ); })
3315
- .Case <mlir::Float32Type>([&](auto ty) { zero = b.f32_val (0 ); })
3316
- .Case <mlir::Float64Type>([&](auto ty) { zero = b.f64_val (0 ); });
3310
+ Value zero =
3311
+ TypeSwitch<mlir::Type, Value>(valueElemTy)
3312
+ .Case <mlir::IntegerType>(
3313
+ [&](auto ty) { return b.int_val (valueElemNBits, 0 ); })
3314
+ .Case <mlir::Float16Type>([&](auto ) { return b.f16_val (0 ); })
3315
+ .Case <mlir::BFloat16Type>([&](auto ) { return b.bf16_val (0 ); })
3316
+ .Case <mlir::Float32Type>([&](auto ) { return b.f32_val (0 ); })
3317
+ .Case <mlir::Float64Type>([&](auto ) { return b.f64_val (0 ); });
3317
3318
3318
3319
// TODO: check device capabilities to avoid unnecessary emulation or
3319
3320
// emit unsupported feature error.
3320
3321
Value ret;
3321
3322
bool support16BitAtomics = moduleOp->hasAttr (
3322
3323
TritonIntelGPUDialect::getSupport16BitAtomicsAttrName ());
3323
3324
if (valueElemNBits == 16 && !support16BitAtomics) {
3324
- op.emitWarning (
3325
- " 'tt.atomic_rmw' op fp16 datatype is not supported in the target "
3326
- " HW, software emulation is an experimental feature (use at own "
3327
- " risk)" );
3328
- Block *endBlock = emulateFp16AtomicRmw (
3325
+ op.emitWarning (" 'tt.atomic_rmw' op fp16/bf16 datatype is not supported "
3326
+ " in the target HW, software emulation is an "
3327
+ " experimental feature (use at own risk)" );
3328
+ Block *endBlock = emulate16BitsAtomicRmw (
3329
3329
rewriter, loc, atomicRmwAttr, valueElemTy, rmwPtr, rmwVal,
3330
3330
maybeAnd (rewriter, loc, b.true_val (), rmwMask), {zero});
3331
3331
ret = endBlock->getArgument (0 );
@@ -3391,10 +3391,10 @@ struct AtomicRMWOpConversion
3391
3391
3392
3392
// Emulate 16-bit atomicrmw through a loop with 32-bit cmpxchg.
3393
3393
// TODO: optimize for the case when rmwMask is a true constant?
3394
- Block *emulateFp16AtomicRmw (ConversionPatternRewriter &rewriter, Location loc ,
3395
- mlir::triton::RMWOp atomicOp, Type valueElemTy ,
3396
- Value rmwPtr , Value rmwVal , Value rmwMask ,
3397
- ArrayRef<Value> ops) const {
3394
+ Block *emulate16BitsAtomicRmw (ConversionPatternRewriter &rewriter,
3395
+ Location loc, mlir::triton::RMWOp atomicOp,
3396
+ Type valueElemTy , Value rmwPtr , Value rmwVal ,
3397
+ Value rmwMask, ArrayRef<Value> ops) const {
3398
3398
auto b = TritonLLVMOpBuilder (loc, rewriter);
3399
3399
Block *insertionBlock = rewriter.getInsertionBlock ();
3400
3400
Block *headerBlock =
0 commit comments