Skip to content

Commit 21f45c4

Browse files
committed
[mlir][AMDGPU] Allow non-contiguous destination memrefs for gather_to_lds
The requirement that the LDS operand is contiguous is overly restrictive because it's perfectly valid to have a subview depend on subgroup IDs that is still subgroup contiguous. We could continue trying to do this verification based on the number of copied elements, but instead this change just opts to clarify the semantics on the op definition.
1 parent 1194353 commit 21f45c4

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,8 @@ def AMDGPU_GatherToLDSOp :
907907
The elements gathered by the subgroup will be written contiguously in order of lane ID
908908
starting at `$dst[$dstIndices]`. Byte-sized (ex. i8) or short-sized (ex. i16)
909909
types will be zero-padded/extended to 32 bits before being written. 96-bit types
910-
(ex. vector<3xf32>) will be zero-padded to 128 bits before being written.
910+
(ex. vector<3xf32>) will be zero-padded to 128 bits before being written. Only the
911+
offsets held by lane 0 are used.
911912
* `$transferType`: type of the data to be transferred by each thread. This is used to determine
912913
the size of the data to be transferred and the number of threads in the subgroup.
913914
The transfer type must be a scalar type or a vector type with a single element type.

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,6 @@ LogicalResult GatherToLDSOp::verify() {
518518
MemRefType srcType = cast<MemRefType>(getSrc().getType());
519519
MemRefType dstType = cast<MemRefType>(getDst().getType());
520520

521-
if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
522-
return emitOpError("destination types must be contiguous");
523-
524521
auto elemType = srcType.getElementType();
525522
// Check $src and $dst element types are the same.
526523
if (elemType != dstType.getElementType())

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,13 +539,15 @@ func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16
539539
}
540540

541541
// CHECK-LABEL: func @gather_to_lds
542-
func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space<workgroup>>, %smem2 : memref<32x32xf16, #gpu.address_space<workgroup>>) {
542+
func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space<workgroup>>, %smem2 : memref<32x32xf16, #gpu.address_space<workgroup>>, %smem3 : memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>) {
543543
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
544544
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}]
545545
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
546+
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
546547
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32x32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
547548
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem1[%idx1] : vector<2xf16>, memref<32x32xf16>, memref<32xf16, #gpu.address_space<workgroup>>
548549
amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
550+
amdgpu.gather_to_lds %mem1[%idx1], %smem3[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>
549551
func.return
550552
}
551553

0 commit comments

Comments
 (0)