Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def AMDGPU_RawBufferAtomicCmpswapOp :
def AMDGPU_RawBufferAtomicFaddOp :
AMDGPU_Op<"raw_buffer_atomic_fadd", [AllElementTypesMatch<["value", "memref"]>,
AttrSizedOperandSegments]>,
Arguments<(ins AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16]>]>:$value,
Arguments<(ins AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16, BF16]>]>:$value,
Arg<AnyMemRef, "buffer to operate on", [MemRead, MemWrite]>:$memref,
Variadic<I32>:$indices,
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
Expand Down
9 changes: 1 addition & 8 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
// bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
// and the total load size is >= 32, use a vector load of N / (bitsize(T) /
// 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
// so bitcast any floats to integers. On top of all this, cast bfloat
// (vectors) to i16 since the backend doesn't currently support bfloat on
// these operations.
// so bitcast any floats to integers.
Type llvmBufferValType = llvmWantedDataType;
if (wantedDataType.isBF16())
llvmBufferValType = rewriter.getI16Type();
if (auto wantedVecType = dyn_cast<VectorType>(wantedDataType))
if (wantedVecType.getElementType().isBF16())
llvmBufferValType = wantedVecType.clone(rewriter.getI16Type());
if (atomicCmpData) {
if (auto floatType = dyn_cast<FloatType>(wantedDataType))
llvmBufferValType = this->getTypeConverter()->convertType(
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,17 @@ func.func @gpu_gcn_raw_buffer_atomic_fadd_v2f16(%value: vector<2xf16>, %buf: mem
func.return
}

// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16
func.func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16(%value: vector<2xbf16>, %buf: memref<64xbf16>, %idx: i32) {
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i32)
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
// CHECK: rocdl.raw.ptr.buffer.atomic.fadd %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : vector<2xbf16>
amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : vector<2xbf16> -> memref<64xbf16>, i32
func.return
}

// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fmax_f32
func.func @gpu_gcn_raw_buffer_atomic_fmax_f32(%value: f32, %buf: memref<64xf32>, %idx: i32) {
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
Expand Down
Loading