Skip to content

Commit dda01c1

Browse files
committed
Restrict inner most dim
1 parent 21f45c4 commit dda01c1

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,9 @@ LogicalResult GatherToLDSOp::verify() {
518518
MemRefType srcType = cast<MemRefType>(getSrc().getType());
519519
MemRefType dstType = cast<MemRefType>(getDst().getType());
520520

521+
if (!dstType.areTrailingDimsContiguous(1))
522+
return emitOpError("destination type inner most dim must be contiguous");
523+
521524
auto elemType = srcType.getElementType();
522525
// Check $src and $dst element types are the same.
523526
if (elemType != dstType.getElementType())

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,11 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
230230
amdgpu.gather_to_lds %mem1[%idx1], %mem2[%idx1] : vector<2xf16>, memref<32xf16>, memref<32xf16>
231231
func.return
232232
}
233+
234+
// -----
235+
236+
func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : memref<32xf16, strided<[?]>, #gpu.address_space<workgroup>>) {
237+
// expected-error@+1 {{'amdgpu.gather_to_lds' op destination type inner most dim must be contiguous}}
238+
amdgpu.gather_to_lds %mem1[%idx1], %mem2[%idx1] : vector<2xf16>, memref<32xf16>, memref<32xf16, strided<[?]>, #gpu.address_space<workgroup>>
239+
func.return
240+
}

0 commit comments

Comments
 (0)