diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 92aacdaef4136..2c646934c11c2 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -907,7 +907,8 @@ def AMDGPU_GatherToLDSOp : The elements gathered by the subgroup will be written contiguously in order of lane ID starting at `$dst[$dstIndices]`. Byte-sized (ex. i8) or short-sized (ex. i16) types will be zero-padded/extended to 32 bits before being written. 96-bit types - (ex. vector<3xf32>) will be zero-padded to 128 bits before being written. + (ex. vector<3xf32>) will be zero-padded to 128 bits before being written. Only the + offsets held by lane 0 are used. * `$transferType`: type of the data to be transferred by each thread. This is used to determine the size of the data to be transferred and the number of threads in the subgroup. The transfer type must be a scalar type or a vector type with a single element type. diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 9a0a230e8abca..d7ffdcb58ddb5 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -518,8 +518,8 @@ LogicalResult GatherToLDSOp::verify() { MemRefType srcType = cast(getSrc().getType()); MemRefType dstType = cast(getDst().getType()); - if (!dstType.areTrailingDimsContiguous(dstType.getRank())) - return emitOpError("destination types must be contiguous"); + if (!dstType.areTrailingDimsContiguous(1)) + return emitOpError("destination type inner most dim must be contiguous"); auto elemType = srcType.getElementType(); // Check $src and $dst element types are the same. diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 0d2fd245af9e2..66e7dd4014af9 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -230,3 +230,11 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : amdgpu.gather_to_lds %mem1[%idx1], %mem2[%idx1] : vector<2xf16>, memref<32xf16>, memref<32xf16> func.return } + +// ----- + +func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : memref<32xf16, strided<[?]>, #gpu.address_space>) { + // expected-error@+1 {{'amdgpu.gather_to_lds' op destination type inner most dim must be contiguous}} + amdgpu.gather_to_lds %mem1[%idx1], %mem2[%idx1] : vector<2xf16>, memref<32xf16>, memref<32xf16, strided<[?]>, #gpu.address_space> + func.return +} diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index fe78b5365745a..87e11c028c62a 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -539,13 +539,15 @@ func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16 } // CHECK-LABEL: func @gather_to_lds -func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space>, %smem2 : memref<32x32xf16, #gpu.address_space>) { +func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space>, %smem2 : memref<32x32xf16, #gpu.address_space>, %smem3 : memref, #gpu.address_space>) { // CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}] // CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}] // CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}] + // CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}] amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32x32xf16>, memref<32x32xf16, #gpu.address_space> amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem1[%idx1] : vector<2xf16>, memref<32x32xf16>, memref<32xf16, #gpu.address_space> amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space> + amdgpu.gather_to_lds %mem1[%idx1], %smem3[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref, #gpu.address_space> func.return }