Skip to content

Commit a55d425

Browse files
[MLIR] Determine contiguousness of memrefs with a dynamic dimension
Memrefs where only the leftmost dimension of the trailing ones to check for contiguity is dynamic can be reasoned about.
1 parent d013556 commit a55d425

File tree

2 files changed

+86
-11
lines changed

2 files changed

+86
-11
lines changed

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,10 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
649649
if (!isLastDimUnitStride())
650650
return false;
651651

652-
auto memrefShape = getShape().take_back(n);
652+
if (n == 1)
653+
return true;
654+
655+
auto memrefShape = getShape().take_back(n-1);
653656
if (ShapedType::isDynamicShape(memrefShape))
654657
return false;
655658

@@ -668,7 +671,7 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
668671
// Check whether strides match "flattened" dims.
669672
SmallVector<int64_t> flattenedDims;
670673
auto dimProduct = 1;
671-
for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
674+
for (auto dim : llvm::reverse(memrefShape)) {
672675
dimProduct *= dim;
673676
flattenedDims.push_back(dimProduct);
674677
}

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,20 @@ func.func @transfer_read_leading_dynamic_dims(
188188

189189
// -----
190190

191-
// One of the dims to be flattened is dynamic - not supported ATM.
191+
// One of the dims to be flattened is dynamic and not the leftmost - not
192+
// possible to reason whether the memref is contiguous as the dynamic dimension
193+
// could be one and the corresponding stride could be arbitrary.
192194

193195
func.func @negative_transfer_read_dynamic_dim_to_flatten(
194196
%idx_1: index,
195197
%idx_2: index,
196-
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
198+
%mem: memref<1x4x?x6xi32>) -> vector<1x2x6xi32> {
197199

198200
%c0 = arith.constant 0 : index
199201
%c0_i32 = arith.constant 0 : i32
200202
%res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
201203
in_bounds = [true, true, true]
202-
} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
204+
} : memref<1x4x?x6xi32>, vector<1x2x6xi32>
203205
return %res : vector<1x2x6xi32>
204206
}
205207

@@ -212,6 +214,41 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
212214

213215
// -----
214216

217+
// One of the dims to be flattened is dynamic and leftmost.
218+
219+
func.func @transfer_read_dynamic_leftmost_dim_to_flatten(
220+
%idx_1: index,
221+
%idx_2: index,
222+
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
223+
224+
%c0 = arith.constant 0 : index
225+
%c0_i32 = arith.constant 0 : i32
226+
%res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
227+
in_bounds = [true, true, true]
228+
} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
229+
return %res : vector<1x2x6xi32>
230+
}
231+
232+
// CHECK-LABEL: func.func @transfer_read_dynamic_leftmost_dim_to_flatten
233+
// CHECK-SAME: %[[IDX_1:arg0]]: index
234+
// CHECK-SAME: %[[IDX_2:arg1]]: index
235+
// CHECK-SAME: %[[MEM:arg2]]: memref<1x?x4x6xi32>
236+
// CHECK-NEXT: %[[C0_I32:.+]] = arith.constant 0 : i32
237+
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
238+
// CHECK-NEXT: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
239+
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
240+
// CHECK-NEXT: %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]]
241+
// CHECK-NEXT: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
242+
// CHECK-SAME: [%[[C0]], %[[TMP]]], %[[C0_I32]]
243+
// CHECK-SAME: {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
244+
// CHECK-NEXT: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<12xi32> to vector<1x2x6xi32>
245+
// CHECK-NEXT: return %[[RES]] : vector<1x2x6xi32>
246+
247+
// CHECK-128B-LABEL: func @transfer_read_dynamic_leftmost_dim_to_flatten
248+
// CHECK-128B-NOT: memref.collapse_shape
249+
250+
// -----
251+
215252
// The vector to be read represents a _non-contiguous_ slice of the input
216253
// memref.
217254

@@ -451,26 +488,61 @@ func.func @transfer_write_leading_dynamic_dims(
451488

452489
// -----
453490

454-
// One of the dims to be flattened is dynamic - not supported ATM.
491+
// One of the dims to be flattened is dynamic and not leftmost.
455492

456-
func.func @negative_transfer_write_dynamic_to_flatten(
493+
func.func @negative_transfer_write_dynamic_dim_to_flatten(
457494
%idx_1: index,
458495
%idx_2: index,
459496
%vec : vector<1x2x6xi32>,
460-
%mem: memref<1x?x4x6xi32>) {
497+
%mem: memref<1x4x?x6xi32>) {
461498

462499
%c0 = arith.constant 0 : index
463500
%c0_i32 = arith.constant 0 : i32
464501
vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
465-
vector<1x2x6xi32>, memref<1x?x4x6xi32>
502+
vector<1x2x6xi32>, memref<1x4x?x6xi32>
466503
return
467504
}
468505

469-
// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
506+
// CHECK-LABEL: func.func @negative_transfer_write_dynamic_dim_to_flatten
470507
// CHECK-NOT: memref.collapse_shape
471508
// CHECK-NOT: vector.shape_cast
472509

473-
// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
510+
// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_dim_to_flatten
511+
// CHECK-128B-NOT: memref.collapse_shape
512+
513+
// -----
514+
515+
// One of the dims to be flattened is dynamic and leftmost.
516+
517+
func.func @transfer_write_dynamic_leftmost_dim_to_flatten(
518+
%idx_1: index,
519+
%idx_2: index,
520+
%vec : vector<1x2x6xi32>,
521+
%mem: memref<1x?x4x6xi32>) {
522+
523+
%c0 = arith.constant 0 : index
524+
%c0_i32 = arith.constant 0 : i32
525+
vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
526+
vector<1x2x6xi32>, memref<1x?x4x6xi32>
527+
return
528+
}
529+
530+
// CHECK-LABEL: func.func @transfer_write_dynamic_leftmost_dim_to_flatten
531+
// CHECK-SAME: %[[IDX_1:arg0]]: index
532+
// CHECK-SAME: %[[IDX_2:arg1]]: index
533+
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>,
534+
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
535+
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
536+
// CHECK-NEXT: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
537+
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
538+
// CHECK-NEXT: %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]]
539+
// CHECK-NEXT: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
540+
// CHECK-NEXT: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
541+
// CHECK-SAME: [%[[C0]], %[[TMP]]]
542+
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
543+
// CHECK-NEXT: return
544+
545+
// CHECK-128B-LABEL: func @transfer_write_dynamic_leftmost_dim_to_flatten
474546
// CHECK-128B-NOT: memref.collapse_shape
475547

476548
// -----

0 commit comments

Comments
 (0)