Skip to content

Commit 6b29ee9

Browse files
authored
[mlir][amdgpu] Properly handle mismatching memref ranks in amdgpu.gather_to_lds (#149407)
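This op doesn't have any rank or indices restrictions on src/dst memrefs, but was using `SameVariadicOperandSize`, which forces all variadic operand groups (here, the source and destination index lists) to have the same length and so causes issues when the src and dst memref ranks differ. Also fix some other issues while we're at it.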
1 parent b826429 commit 6b29ee9

File tree

5 files changed

+36
-10
lines changed

5 files changed

+36
-10
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def AMDGPU_ScaledExtPackedOp
127127
let summary = "Extend a vector of packed floating point values";
128128

129129
let description = [{
130-
Extend and scale two packed floats in `source[index]` to two floats and
130+
Extend and scale two packed floats in `source[index]` to two floats and
131131
return them.
132132

133133
This rather unusual signature arises from the fact that AMD GPUs cannot
@@ -861,7 +861,7 @@ def AMDGPU_WMMAOp :
861861
}
862862

863863
def AMDGPU_GatherToLDSOp :
864-
AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
864+
AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
865865
Arguments<(ins
866866
Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
867867
Variadic<Index>:$srcIndices,
@@ -966,13 +966,13 @@ def AMDGPU_ScaledMFMAOp :
966966
order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
967967

968968
This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
969-
- `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
970-
fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
971-
size.
972-
- `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
969+
- `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
970+
fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
971+
size.
972+
- `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
973973
are omitted from this wrapper.
974-
- The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
975-
double-precision operations on gfx94x and so are not included here.
974+
- The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
975+
double-precision operations on gfx94x and so are not included here.
976976
}];
977977
let assemblyFormat = [{
978978
`(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
134134
}
135135

136136
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
137+
if (!memorySpace)
138+
return false;
137139
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
138140
return intMemorySpace.getInt() == 3;
139141
if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
@@ -142,6 +144,8 @@ static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
142144
}
143145

144146
static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
147+
if (!memorySpace)
148+
return false;
145149
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
146150
return intMemorySpace.getInt() == 7;
147151
if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))

mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,15 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
127127
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
128128
// CHECK: %[[ALLOC:.*]] = memref.alloc()
129129
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
130+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
131+
// CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
130132
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
131133
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
132134
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
133135
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
134136
// CHECK: %[[DSTIDX:.*]] = llvm.mul %[[DSTIDX_CAST]], %[[C64]] : i64
135-
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX]]]
137+
// CHECK: %[[DSTIDX1:.*]] = llvm.add %[[DSTIDX]], %[[C0_I64]] : i64
138+
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX1]]]
136139
// CHECK: rocdl.load.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 4
137140
%alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
138141
%c0 = arith.constant 0 : index
@@ -151,7 +154,7 @@ func.func @fat_buffer_load_to_rocdl_f32(%global : memref<128x72xf32, #amdgpu_fat
151154
// CHECK: %[[BUFFER_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
152155

153156
// CHECK: %[[C0:.*]] = arith.constant 0 : index
154-
// CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
157+
// CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
155158
// CHECK: %[[C12:.*]] = arith.constant 12 : index
156159
// CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
157160
// CHECK: %[[C32:.*]] = arith.constant 32 : index

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,11 @@ func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : me
222222
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<8xi6>
223223
func.return %0 : vector<8xi6>
224224
}
225+
226+
// -----
227+
228+
func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : memref<32xf16>) {
229+
// expected-error@+1 {{'amdgpu.gather_to_lds' op destination memory address space must be Workgroup}}
230+
amdgpu.gather_to_lds %mem1[%idx1], %mem2[%idx1] : vector<2xf16>, memref<32xf16>, memref<32xf16>
231+
func.return
232+
}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,3 +493,14 @@ func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16
493493
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<4xf16>
494494
func.return %0 : vector<4xf16>
495495
}
496+
497+
// CHECK-LABEL: func @gather_to_lds
498+
func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space<workgroup>>, %smem2 : memref<32x32xf16, #gpu.address_space<workgroup>>) {
499+
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
500+
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}]
501+
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
502+
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32x32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
503+
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem1[%idx1] : vector<2xf16>, memref<32x32xf16>, memref<32xf16, #gpu.address_space<workgroup>>
504+
amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
505+
func.return
506+
}

0 commit comments

Comments
 (0)