diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 6cd3408e2b2e9..1dfe2a57df587 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -40,21 +40,25 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
   let description = [{
     The `affine.apply` operation applies an [affine mapping](#affine-maps)
     to a list of SSA values, yielding a single SSA value. The number of
-    dimension and symbol arguments to `affine.apply` must be equal to the
+    dimension and symbol operands to `affine.apply` must be equal to the
     respective number of dimensional and symbolic inputs to the affine
     mapping; the affine mapping has to be one-dimensional, and so the
     `affine.apply` operation always returns one value. The input operands
     and result must all have ‘index’ type.

+    An operand that is a valid dimension as per the [rules on valid affine
+    dimensions and symbols](#restrictions-on-dimensions-and-symbols)
+    cannot be used as a symbolic operand.
+
     Example:

    ```mlir
-    #map10 = affine_map<(d0, d1) -> (d0 floordiv 8 + d1 floordiv 128)>
+    #map = affine_map<(d0, d1) -> (d0 floordiv 8 + d1 floordiv 128)>
     ...
-    %1 = affine.apply #map10 (%s, %t)
+    %1 = affine.apply #map (%s, %t)

     // Inline example.
-    %2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n]
+    %2 = affine.apply affine_map<(i)[s0] -> (i + s0)> (%42)[%n]
    ```
   }];

   let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$mapOperands);
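To make the newly documented restriction concrete, here is a minimal sketch (not part of the patch; the values `%i` and `%0` are hypothetical): an `affine.for` induction variable is a valid affine dimension, so binding it to a symbol of the map is rejected, while binding it to a dimension is accepted.

```mlir
affine.for %i = 0 to 8 {
  // Rejected by the new verifier check: %i is a valid dim bound to a
  // symbol position of the map.
  // %0 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%i]

  // Accepted: the same value passed as a dimensional operand.
  %0 = affine.apply affine_map<(d0) -> (d0 * 64)>(%i)
}
```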
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index aa49c49062c76..5d0055993e5fd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -578,6 +578,15 @@ LogicalResult AffineApplyOp::verify() {
   if (affineMap.getNumResults() != 1)
     return emitOpError("mapping must produce one value");

+  // Do not allow valid dims to be used in symbol positions. We do allow
+  // affine.apply to use operands for values that may qualify as neither affine
+  // dims nor affine symbols due to usage outside of affine ops, analyses, etc.
+  Region *region = getAffineScope(*this);
+  for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
+    if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
+      return emitError("dimensional operand cannot be used as a symbol");
+  }
+
   return success();
 }

@@ -1359,13 +1368,64 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
   resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
   *operands = resultOperands;

-  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
-                                              oldNumSyms + nextSym);
+  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
+      dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);

   assert(mapOrSet->getNumInputs() == operands->size() &&
          "map/set inputs must match number of operands");
 }

+/// A valid affine dimension may appear as a symbol in affine.apply operations.
+/// Given an application of `operands` to an affine map or integer set
+/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are
+/// valid dims, but not valid symbols, into actual dims. Without such a
+/// legalization, the affine.apply will be invalid. This method is the exact
+/// inverse of canonicalizePromotedSymbols.
+template <class MapOrSet>
+static void legalizeDemotedDims(MapOrSet &mapOrSet,
+                                SmallVectorImpl<Value> &operands) {
+  if (!mapOrSet || operands.empty())
+    return;
+
+  unsigned numOperands = operands.size();
+
+  assert(mapOrSet.getNumInputs() == numOperands &&
+         "map/set inputs must match number of operands");
+
+  auto *context = mapOrSet.getContext();
+  SmallVector<Value, 8> resultOperands;
+  resultOperands.reserve(numOperands);
+  SmallVector<Value, 8> remappedDims;
+  remappedDims.reserve(numOperands);
+  SmallVector<Value, 8> symOperands;
+  symOperands.reserve(mapOrSet.getNumSymbols());
+  unsigned nextSym = 0;
+  unsigned nextDim = 0;
+  unsigned oldNumDims = mapOrSet.getNumDims();
+  SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
+  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
+  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
+    if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
+      // This is a valid dim that appears as a symbol, legalize it.
+      symRemapping[i - oldNumDims] =
+          getAffineDimExpr(oldNumDims + nextDim++, context);
+      remappedDims.push_back(operands[i]);
+    } else {
+      symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
+      symOperands.push_back(operands[i]);
+    }
+  }
+
+  append_range(resultOperands, remappedDims);
+  append_range(resultOperands, symOperands);
+  operands = resultOperands;
+  mapOrSet = mapOrSet.replaceDimsAndSymbols(
+      /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
+
+  assert(mapOrSet.getNumInputs() == operands.size() &&
+         "map/set inputs must match number of operands");
+}
+
 // Works for either an affine map or an integer set.
 template <class MapOrSet>
 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
@@ -1380,6 +1440,7 @@ static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
          "map/set inputs must match number of operands");

   canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
+  legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);

   // Check to see what dims are used.
   llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
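To illustrate what `legalizeDemotedDims` does during canonicalization, here is a sketch under the assumption that `%i` is a surrounding `affine.for` induction variable (a valid dim but not a valid symbol) and `%n` is a top-level value that is a valid symbol; both names are hypothetical:

```mlir
// Before: %i, a valid dim but not a valid symbol, is bound to symbol s0.
%0 = affine.apply affine_map<()[s0, s1] -> (s0 * 23 + s1)>()[%i, %n]

// After legalizeDemotedDims: %i is remapped to dim d0, while %n remains
// bound to a symbol.
%0 = affine.apply affine_map<(d0)[s0] -> (d0 * 23 + s0)>(%i)[%n]
```

Note that the function appends remapped dims after the existing dim operands, so operands that were already correct stay in place; only the offending symbol operands move.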
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d39c0c6e41df2..e56079c1cccd4 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1460,8 +1460,8 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
 func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
   // CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
   affine.for %arg3 = 0 to 8 {
-    %1 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg3]
-    // CHECK: affine.prefetch [[PARAM_0_]][symbol([[I_0_]]) * 64], read, locality<3>, data : memref<512xf32>
+    %1 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
+    // CHECK: affine.prefetch [[PARAM_0_]][[[I_0_]] * 64], read, locality<3>, data : memref<512xf32>
     affine.prefetch %arg0[%1], read, locality<3>, data : memref<512xf32>
   }
   return
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 9bbd19c381163..9703c05fff8f6 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -563,3 +563,17 @@ func.func @no_upper_bound() {
   }
   return
 }
+
+// -----
+
+func.func @invalid_symbol() {
+  affine.for %arg1 = 0 to 1 {
+    affine.for %arg2 = 0 to 26 {
+      affine.for %arg3 = 0 to 23 {
+        affine.apply affine_map<()[s0, s1] -> (s0 * 23 + s1)>()[%arg1, %arg3]
+        // expected-error@above {{dimensional operand cannot be used as a symbol}}
+      }
+    }
+  }
+  return
+}
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3160fd9c65c04..a27fbf26e13d8 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -496,8 +496,8 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16

 // -----

-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
 // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -518,16 +518,16 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
 // CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref>
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
-// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
-// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
+// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
 // CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
-// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
 // CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>

 // -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -549,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
-// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>

 // -----

-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -578,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
-// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>

 // -----

-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -608,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
-// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>

 // -----

@@ -678,7 +678,7 @@ func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,

 // -----

 // CHECK-LABEL: func @fold_store_keep_nontemporal(
-// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}]  {nontemporal = true} : memref<12x32xf32>
+// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
 func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>