Skip to content

Commit 065d0c0

Browse files
authored
Merge pull request #537 from Xilinx/corentin.fix_affine_if
Consider linalg.index an affine dim
2 parents ee712a6 + b9293b0 commit 065d0c0

File tree

14 files changed

+123
-28
lines changed

14 files changed

+123
-28
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
1616

1717
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
18+
#include "mlir/Dialect/Affine/IR/AffineTraits.h"
1819
#include "mlir/Dialect/Arith/IR/Arith.h"
1920
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2021
#include "mlir/IR/AffineMap.h"

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef AFFINE_OPS
1414
#define AFFINE_OPS
1515

16+
include "mlir/Dialect/Affine/IR/AffineTraits.td"
1617
include "mlir/Dialect/Arith/IR/ArithBase.td"
1718
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
1819
include "mlir/Interfaces/ControlFlowInterfaces.td"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- AffineTraits.h - MLIR Affine Traits --------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines traits brought in by the Affine dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef AFFINE_TRAITS_H
13+
#define AFFINE_TRAITS_H
14+
15+
#include "mlir/IR/OpDefinition.h"
16+
17+
namespace mlir::OpTrait {
18+
19+
template <typename ConcreteType>
20+
class AffineDim : public TraitBase<ConcreteType, AffineDim> {
21+
public:
22+
static LogicalResult verifyTrait(Operation *op) { return success(); }
23+
};
24+
25+
} // namespace mlir::OpTrait
26+
27+
#endif // AFFINE_TRAITS_H
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- AffineTraits.td - Affine dialect traits -------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines traits brought in by the MLIR Affine dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef AFFINE_TRAITS
13+
#define AFFINE_TRAITS
14+
15+
include "mlir/IR/OpBase.td"
16+
17+
// Trait to declare that an op result is an affine dimension identifier.
18+
// Prevents the result from being seen as a symbol into AffineMaps
19+
// and IntegerSets.
20+
// This is a deviation from upstream to consider linalg.index as
21+
// a dimension rather than a symbol. See this PR:
22+
// https://github.com/Xilinx/llvm-project/pull/537
23+
def AffineDim : NativeOpTrait<"AffineDim">;
24+
25+
#endif // AFFINE_TRAITS

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_LINALG_IR_LINALG_H
1111

1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Affine/IR/AffineTraits.h"
1314
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1415
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1516
#include "mlir/IR/AffineExpr.h"

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LINALG_OPS
1414
#define LINALG_OPS
1515

16+
include "mlir/Dialect/Affine/IR/AffineTraits.td"
1617
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
1718
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
1819
include "mlir/Interfaces/ControlFlowInterfaces.td"
@@ -46,7 +47,7 @@ def Linalg_YieldOp : Linalg_Op<"yield", [Pure, ReturnLike, Terminator]>,
4647
let hasVerifier = 1;
4748
}
4849

49-
def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
50+
def Linalg_IndexOp : Linalg_Op<"index", [Pure, AffineDim]>,
5051
Arguments<(ins ConfinedAttr<I64Attr, [IntMinValue<0>]>:$dim)>,
5152
Results<(outs Index:$result)> {
5253
let summary = "linalg index operation";

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1010
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1112
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1213
#include "mlir/Dialect/UB/IR/UBOps.h"
1314
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -312,6 +313,10 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
312313
return isa<AffineForOp, AffineParallelOp>(parentOp);
313314
}
314315

316+
// Remove me: linalg.index ops are valid affine dim identifiers
317+
if (op->hasTrait<OpTrait::AffineDim>())
318+
return true;
319+
315320
// Affine apply operation is ok if all of its operands are ok.
316321
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
317322
return applyOp.isValidDim(region);
@@ -439,6 +444,10 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
439444
return false;
440445
}
441446

447+
// Remove me: linalg.index ops are not valid affine symbols
448+
if (defOp->hasTrait<OpTrait::AffineDim>())
449+
return false;
450+
442451
// Constant operation is ok.
443452
Attribute operandCst;
444453
if (matchPattern(defOp, m_Constant(&operandCst)))
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: mlir-opt -canonicalize %s | FileCheck %s
2+
3+
// Check that linalg.index does not cause folding of affine.if set to
4+
// a symbolic set.
5+
// This is a deviation from upstream MLIR.
6+
// The origin of this test is that PR:
7+
// https://github.com/Xilinx/llvm-project/pull/537
8+
9+
// CHECK: = affine_set<(d0) : (-d0 + 5 >= 0)>
10+
#set = affine_set<(d0) : (-d0 + 5 >= 0)>
11+
12+
func.func @linalg_index_affine_if(%in: tensor<10xf32>) -> tensor<10xf32> {
13+
%empty = tensor.empty() : tensor<10xf32>
14+
%out = linalg.generic {
15+
indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
16+
iterator_types = ["parallel"]}
17+
ins(%in : tensor<10xf32>)
18+
outs(%empty : tensor<10xf32>) {
19+
^bb0(%a: f32, %b: f32):
20+
%c42f = arith.constant 42.0 : f32
21+
%i = linalg.index 0 : index
22+
%ret = affine.if #set(%i) -> f32 {
23+
affine.yield %a : f32
24+
} else {
25+
affine.yield %c42f : f32
26+
}
27+
linalg.yield %ret : f32
28+
} -> tensor<10xf32>
29+
return %out : tensor<10xf32>
30+
}

mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ module attributes {transform.with_named_sequence} {
4040
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
4141

4242
// Compute input channel/convolved indices.
43-
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
44-
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
45-
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
43+
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
44+
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
45+
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
4646

4747
// Extract from the input tensor.
4848
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
@@ -227,9 +227,9 @@ module attributes {transform.with_named_sequence} {
227227
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
228228

229229
// Im2col maps
230-
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 floordiv 9)>
231-
// CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0, s1] -> (s0 floordiv 14 + (s1 mod 9) floordiv 3)>
232-
// CHECK-DAG: #[[MAP8:.+]] = affine_map<()[s0, s1] -> (s0 + s1 - (s0 floordiv 14) * 14 - (s1 floordiv 3) * 3)>
230+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 9)>
231+
// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1) -> (d0 floordiv 14 + (d1 mod 9) floordiv 3)>
232+
// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 - (d0 floordiv 14) * 14 - (d1 floordiv 3) * 3)>
233233

234234

235235
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
@@ -251,9 +251,9 @@ module attributes {transform.with_named_sequence} {
251251
// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index
252252

253253
// Compute input channel/convolved indices.
254-
// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]]()[%[[KINDEX]]]
255-
// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]]()[%[[NINDEX]], %[[KINDEX]]]
256-
// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]]()[%[[NINDEX]], %[[KINDEX]]]
254+
// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]](%[[KINDEX]])
255+
// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]](%[[NINDEX]], %[[KINDEX]])
256+
// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]](%[[NINDEX]], %[[KINDEX]])
257257

258258
// Extract from the input tensor.
259259
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
@@ -300,9 +300,9 @@ module attributes {transform.with_named_sequence} {
300300
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
301301

302302
// Compute input channel/convolved indices.
303-
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
304-
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
305-
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
303+
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
304+
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
305+
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
306306

307307
// Extract from the input tensor.
308308
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -986,13 +986,13 @@ module {
986986
#map3 = affine_map<(d0, d1) -> (d0, d1)>
987987

988988
// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
989-
// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<()[s0] -> (s0 floordiv 4)>
989+
// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0) -> (d0 floordiv 4)>
990990

991991
func.func @fuse_and_collapse(%arg0: tensor<3x4xindex>) -> tensor<2x12xindex> {
992992
%1 = tensor.empty() : tensor<2x3x4xindex>
993993
// CHECK: linalg.generic {
994994
// CHECK: %[[INDEX1:[a-zA-Z0-9_]+]] = linalg.index 1 : index
995-
// CHECK-NEXT: %[[MAP:[a-zA-Z0-9_]+]] = affine.apply #map1()[%[[INDEX1]]]
995+
// CHECK-NEXT: %[[MAP:[a-zA-Z0-9_]+]] = affine.apply #map1(%[[INDEX1]])
996996
// CHECK-NEXT: linalg.yield %[[MAP]] : index
997997
%2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0: tensor<3x4xindex>) outs(%1 : tensor<2x3x4xindex>) {
998998
^bb0(%in: index, %out: index):

0 commit comments

Comments
 (0)