Skip to content

Commit d05e6a7

Browse files
committed
[MLIR][Affine] Fix affine.apply verifier and add functionality to demote invalid symbols to dims
Fixes: #120189, #128403. Fix the affine.apply verifier to reject operands in symbol positions that are valid dims (but not valid symbols) for affine purposes. This doesn't affect users in other contexts where the operands are neither valid dims nor valid symbols (e.g., in scf.for or other region ops). Without this check, it was possible for `-canonicalize` to generate invalid IR when such affine.apply ops were composed. Introduce a method to demote a symbolic operand to a dimensional one (the inverse of the current canonicalizePromotedSymbols). Demote operands that could/should have been valid affine dimensional values (affine loop IVs or their functions) from symbols to dims. This is a general method that can be used to legalize a map + operands post construction depending on its operands. Use it during `canonicalizeMapOrSetAndOperands` so that pattern rewriter-based passes are able to generate valid IR post folding. Users outside of affine analyses/dialects remain unaffected. In some cases, this change also leads to better-simplified operands with duplicates eliminated, as shown in one of the test cases where the same operand appeared both as a symbol and as a dim. This PR also fixes test cases where dimensional positions should ideally have been used with affine.apply (for affine loop IVs, for example).
1 parent d403f33 commit d05e6a7

File tree

5 files changed

+99
-26
lines changed

5 files changed

+99
-26
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,15 @@ LogicalResult AffineApplyOp::verify() {
568568
if (affineMap.getNumResults() != 1)
569569
return emitOpError("mapping must produce one value");
570570

571+
// Do not allow valid dims to be used in symbol positions. We do allow
572+
// affine.apply to use operands for values that may neither qualify as affine
573+
// dims or affine symbols due to usage outside of affine ops, analyses, etc.
574+
Region *region = getAffineScope(*this);
575+
for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
576+
if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
577+
return emitError("dimensional operand cannot be used as a symbol");
578+
}
579+
571580
return success();
572581
}
573582

@@ -1351,13 +1360,62 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
13511360

13521361
resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
13531362
*operands = resultOperands;
1354-
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1355-
oldNumSyms + nextSym);
1363+
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
1364+
dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
13561365

13571366
assert(mapOrSet->getNumInputs() == operands->size() &&
13581367
"map/set inputs must match number of operands");
13591368
}
13601369

1370+
// A valid affine dimension may appear as a symbol in affine.apply operations.
1371+
// This function rewrites symbol operands that are valid dims, but not valid
1372+
// symbols, into actual dims. Without such a legalization, the affine.apply will
1373+
// be invalid. It works in the opposite direction to canonicalizePromotedSymbols.
//
// Existing dims are left untouched; each demoted operand becomes a new dim
// appended after them, and the remaining symbols are renumbered compactly in
// their original order. `operands` is rewritten to match the new map/set:
// original dims, then demoted dims, then the remaining symbols. Does nothing
// when `mapOrSet` is null or `operands` is empty.
1374+
template <class MapOrSet>
1375+
static void legalizeDemotedDims(MapOrSet *mapOrSet,
1376+
SmallVectorImpl<Value> &operands) {
1377+
if (!mapOrSet || operands.empty())
1378+
return;
1379+
1380+
assert(mapOrSet->getNumInputs() == operands.size() &&
1381+
"map/set inputs must match number of operands");
1382+
1383+
auto *context = mapOrSet->getContext();
1384+
SmallVector<Value, 8> resultOperands;
1385+
resultOperands.reserve(operands.size());
1386+
SmallVector<Value, 8> remappedDims;
1387+
remappedDims.reserve(operands.size());
1388+
unsigned nextSym = 0;
1389+
unsigned nextDim = 0;
1390+
unsigned oldNumDims = mapOrSet->getNumDims();
1391+
// One replacement expression per original symbol position.
SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1392+
for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1393+
if (i >= oldNumDims) {
1394+
if (operands[i] && isValidDim(operands[i]) &&
1395+
!isValidSymbol(operands[i])) {
1396+
// This operand is a valid dim (and not a valid symbol) sitting in a
// symbol position: remap it to a fresh dim placed after the existing dims.
1397+
symRemapping[i - oldNumDims] =
1398+
getAffineDimExpr(oldNumDims + nextDim++, context);
1399+
remappedDims.push_back(operands[i]);
1400+
} else {
1401+
// The operand stays a symbol; renumber it to close gaps left by demotions.
symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1402+
resultOperands.push_back(operands[i]);
1403+
}
1404+
} else {
1405+
// Original dim operands keep their positions.
resultOperands.push_back(operands[i]);
1406+
}
1407+
}
1408+
1409+
// Splice the demoted operands in right after the original dims so that the
// operand order matches the new dim numbering.
resultOperands.insert(resultOperands.begin() + oldNumDims,
1410+
remappedDims.begin(), remappedDims.end());
1411+
operands = resultOperands;
1412+
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
1413+
/*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1414+
1415+
assert(mapOrSet->getNumInputs() == operands.size() &&
1416+
"map/set inputs must match number of operands");
1417+
}
1418+
13611419
// Works for either an affine map or an integer set.
13621420
template <class MapOrSet>
13631421
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
@@ -1372,6 +1430,7 @@ static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
13721430
"map/set inputs must match number of operands");
13731431

13741432
canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1433+
legalizeDemotedDims<MapOrSet>(mapOrSet, *operands);
13751434

13761435
// Check to see what dims are used.
13771436
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,8 +1460,8 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
14601460
func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
14611461
// CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
14621462
affine.for %arg3 = 0 to 8 {
1463-
%1 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg3]
1464-
// CHECK: affine.prefetch [[PARAM_0_]][symbol([[I_0_]]) * 64], read, locality<3>, data : memref<512xf32>
1463+
%1 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
1464+
// CHECK: affine.prefetch [[PARAM_0_]][[[I_0_]] * 64], read, locality<3>, data : memref<512xf32>
14651465
affine.prefetch %arg0[%1], read, locality<3>, data : memref<512xf32>
14661466
}
14671467
return

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,17 @@ func.func @no_upper_bound() {
563563
}
564564
return
565565
}
566+
567+
// -----
568+
569+
func.func @invalid_symbol() {
570+
affine.for %arg1 = 0 to 1 {
571+
affine.for %arg2 = 0 to 26 {
572+
affine.for %arg3 = 0 to 23 {
573+
affine.apply affine_map<()[s0, s1] -> (s0 * 23 + s1)>()[%arg1, %arg3]
574+
// expected-error@above {{dimensional operand cannot be used as a symbol}}
575+
}
576+
}
577+
}
578+
return
579+
}

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,8 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
361361

362362
// -----
363363

364-
#map = affine_map<()[s0] -> (s0 + 5)>
365-
#map1 = affine_map<()[s0] -> (s0 + 17)>
364+
#map = affine_map<(d0) -> (d0 + 5)>
365+
#map1 = affine_map<(d0) -> (d0 + 17)>
366366

367367
// Test with non-int/float memref types.
368368

@@ -383,8 +383,8 @@ func.func @memref_index_type() {
383383
}
384384
affine.for %arg3 = 0 to 3 {
385385
%4 = affine.load %alloc_2[%arg3] : memref<3xindex>
386-
%5 = affine.apply #map()[%4]
387-
%6 = affine.apply #map1()[%3]
386+
%5 = affine.apply #map(%4)
387+
%6 = affine.apply #map1(%3)
388388
%7 = memref.load %alloc[%5, %6] : memref<8x18xf32>
389389
affine.store %7, %alloc_1[%arg3] : memref<3xf32>
390390
}

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,8 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
496496

497497
// -----
498498

499-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
500-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
499+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
500+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
501501
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
502502
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
503503
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -518,16 +518,16 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
518518
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
519519
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
520520
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
521-
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
522-
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
521+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
522+
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
523523
// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
524-
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
524+
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
525525
// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
526526

527527
// -----
528528

529-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
530-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
529+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
530+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
531531
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
532532
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
533533
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -549,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
549549
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
550550
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
551551
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
552-
// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
553-
// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
552+
// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
553+
// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
554554
// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
555555

556556
// -----
557557

558-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
559-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
558+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
559+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
560560
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
561561
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
562562
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -578,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
578578
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
579579
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
580580
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
581-
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
582-
// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
581+
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
582+
// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
583583
// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
584584

585585
// -----
586586

587-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
588-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
587+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
588+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
589589
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
590590
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
591591
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -608,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
608608
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
609609
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
610610
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
611-
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
612-
// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
611+
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
612+
// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
613613
// CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
614614

615615
// -----
@@ -678,7 +678,7 @@ func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
678678
// -----
679679

680680
// CHECK-LABEL: func @fold_store_keep_nontemporal(
681-
// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
681+
// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
682682
func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
683683
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
684684
memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>

0 commit comments

Comments
 (0)