Commit 403a8a1
add collapse shape pass for linalg{broadcast,transpose,fill,reduce} (#302)
The pass collapses runs of adjacent tensor dimensions that an op treats uniformly into a single dimension, runs the op at the lower rank, and expands the result back to the original shape afterwards.

linalg.generic with broadcast

before

```mlir
%79 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%76 : tensor<1x1x1x1x1x1x1x1x2x2xi32>) outs(%77 : tensor<2x2x2x2x2x2x2x2x2x2xi32>) attrs = {broadcastDims = array<i64: 0, 1, 2, 3, 4, 5, 6, 7>} {
^bb0(%in: i32, %out: i32):
  linalg.yield %in : i32
} -> tensor<2x2x2x2x2x2x2x2x2x2xi32>
```

after

```mlir
%collapsed_30 = tensor.collapse_shape %89 [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9]] : tensor<1x1x1x1x1x1x1x1x2x2xi32> into tensor<1x4xi32>
%92 = tensor.collapse_shape %77 [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9]] : tensor<2x2x2x2x2x2x2x2x2x2xi32> into tensor<256x4xi32>
%93 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_30 : tensor<1x4xi32>) outs(%92 : tensor<256x4xi32>) attrs = {broadcastDims = array<i64: 0>} {
^bb0(%in: i32, %out: i32):
  linalg.yield %in : i32
} -> tensor<256x4xi32>
%expanded_31 = tensor.expand_shape %93 [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9]] output_shape [2, 2, 2, 2, 2, 2, 2, 2, 2, 2] : tensor<256x4xi32> into tensor<2x2x2x2x2x2x2x2x2x2xi32>
```

linalg.transpose

before

```mlir
%transposed = linalg.transpose ins(%expanded_2 : tensor<2x2x2x2x2x2x2x2x2x2xi64>) outs(%66 : tensor<2x2x2x2x2x2x2x2x2x2xi64>) permutation = [0, 1, 2, 3, 4, 5, 6, 7, 9, 8]
```

after

```mlir
%collapsed = tensor.collapse_shape %expanded_17 [[0, 1, 2, 3, 4, 5, 6, 7], [8], [9]] : tensor<2x2x2x2x2x2x2x2x2x2xi64> into tensor<256x2x2xi64>
%77 = tensor.collapse_shape %66 ...
%transposed = linalg.transpose ins(%collapsed : tensor<256x2x2xi64>) outs(%77 : tensor<256x2x2xi64>) permutation = [0, 2, 1]
%expanded_22 = tensor.expand_shape %transposed ...
```

linalg.fill

before

```mlir
%13 = linalg.fill ins(%c1_i32 : i32) outs(%12 : tensor<1x1x1x1x1x2x1xi32>) -> tensor<1x1x1x1x1x2x1xi32>
```

after

```mlir
%17 = tensor.collapse_shape %12 ...
%18 = linalg.fill ins(%c1_i32 : i32) outs(%17 : tensor<2xi32>) -> tensor<2xi32>
%expanded_6 = tensor.expand_shape %18 [[0, 1, 2, 3, 4, 5, 6]] output_shape [1, 1, 1, 1, 1, 2, 1] : tensor<2xi32> into tensor<1x1x1x1x1x2x1xi32>
```

linalg.reduce

before

```mlir
%reduced = linalg.reduce ins(%transposed : tensor<2x2x2x2x2x2x2x2x2x2xi64>) outs(%68 : tensor<2x2x2x2x2x2x2x2x2xi64>) dimensions = [8]
  (%in: i64, %init: i64) {
    %311 = arith.xori %in, %init : i64
    linalg.yield %311 : i64
  }
```

after

```mlir
%collapsed_20 = tensor.collapse_shape %expanded_19 [[0, 1, 2, 3, 4, 5, 6, 7], [8]] : tensor<2x2x2x2x2x2x2x2x2xi64> into tensor<256x2xi64>
%reduced = linalg.reduce ins(%transposed : tensor<256x2x2xi64>) outs(%collapsed_20 : tensor<256x2xi64>) dimensions = [1]
  (%in: i64, %init: i64) {
    %377 = arith.xori %in, %init : i64
    linalg.yield %377 : i64
  }
%expanded_21 = tensor.expand_shape %reduced [[0, 1, 2, 3, 4, 5, 6, 7], [8]] output_shape [2, 2, 2, 2, 2, 2, 2, 2, 2] : tensor<256x2xi64> into tensor<2x2x2x2x2x2x2x2x2xi64>
```
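All four rewrites follow the same collapse → op → expand idiom. A minimal sketch of how such a rewrite can be assembled with the upstream tensor dialect, assuming static shapes; this is illustrative only, not the pass's actual implementation, and `rewriteAtLowerRank`/`buildOp` are made-up names:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Collapse `src` along the given reassociation groups, rebuild the op on the
// lower-rank value via `buildOp`, then expand the result back to the original
// shape. E.g. reassoc = {{0, 1, 2, 3, 4, 5, 6, 7}, {8, 9}} turns
// tensor<2x2x2x2x2x2x2x2x2x2xi32> into tensor<256x4xi32>, as above.
static Value rewriteAtLowerRank(OpBuilder &b, Location loc, Value src,
                                ArrayRef<ReassociationIndices> reassoc,
                                function_ref<Value(Value)> buildOp) {
  Value collapsed = b.create<tensor::CollapseShapeOp>(loc, src, reassoc);
  Value lowered = buildOp(collapsed);
  // With static shapes, expand_shape's output_shape can be inferred from
  // the original result type.
  return b.create<tensor::ExpandShapeOp>(loc, src.getType(), lowered, reassoc);
}
```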
1 parent e7a375a commit 403a8a1

File tree

10 files changed: +724 -51 lines changed


include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 1 addition & 47 deletions
```diff
@@ -13,6 +13,7 @@
 #include "triton-shared/Analysis/PtrAnalysis.h"
 #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
 #include "triton-shared/Utils/Utils.h"
+#include "triton-shared/Conversion/TritonArithToLinalg/ConversionTools.h"
 
 #include "triton/Dialect/Triton/IR/Dialect.h"
 
@@ -109,10 +110,6 @@ static Value getScalarValue(Value operand, Location loc,
   return nullptr;
 }
 
-static SmallVector<utils::IteratorType> getNParallelLoopsAttrs(unsigned n) {
-  return SmallVector<utils::IteratorType>(n, utils::IteratorType::parallel);
-}
-
 // if order is empty, transpose the last two dimensions
 // otherwise, use the provided order.
 // The order must be a permutation of the source rank.
@@ -656,49 +653,6 @@ struct BroadcastConverter : public OpConversionPattern<triton::BroadcastOp> {
 private:
   using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
 
-  SmallVector<int64_t> getBroadcastDims(RankedTensorType src,
-                                        RankedTensorType dst) const {
-    SmallVector<int64_t> broadcastDims;
-    auto srcShape = src.getShape();
-    auto dstShape = dst.getShape();
-
-    for (size_t i = 0; i < srcShape.size(); i++) {
-      if (dstShape[i] != srcShape[i]) {
-        assert(srcShape[i] == 1);
-        broadcastDims.push_back(i);
-      }
-    }
-    assert(!broadcastDims.empty() && "cannot identify broadcast dimension");
-    return broadcastDims;
-  }
-
-  // Broadcasts input tensor based on TosaToLinalg's broadcastToShape
-  AffineMap getBroadcastAffineMap(MLIRContext *context,
-                                  ArrayRef<int64_t> inputShape,
-                                  ArrayRef<int64_t> broadcastToShape) const {
-
-    assert(broadcastToShape.size() >= inputShape.size());
-
-    // Create affine map and shapes for tensor initialization.
-    SmallVector<AffineExpr> outExpr;
-
-    size_t diff = broadcastToShape.size() - inputShape.size();
-    for (size_t i = 0; i < broadcastToShape.size(); i++) {
-      if (i < diff) {
-        continue;
-      }
-      size_t j = i - diff;
-      if (inputShape[j] == 1) {
-        // Broadcast singleton dimension
-        outExpr.push_back(mlir::getAffineConstantExpr(0, context));
-        continue;
-      }
-      // Non-broadcast case
-      outExpr.push_back(mlir::getAffineDimExpr(i, context));
-    }
-    return AffineMap::get(broadcastToShape.size(), 0, outExpr, context);
-  }
-
 public:
   LogicalResult
   matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
```

include/triton-shared/Conversion/TritonArithToLinalg/ConversionTools.h

Lines changed: 59 additions & 0 deletions (new file)

```cpp
#ifndef TRITON_CONVERSION_TRITONARITHTOLINALG_CONVERSIONTOOLS_H
#define TRITON_CONVERSION_TRITONARITHTOLINALG_CONVERSIONTOOLS_H

#include "mlir/Dialect/Affine/IR/AffineOps.h"

namespace mlir {
namespace triton {

static inline SmallVector<utils::IteratorType> getNParallelLoopsAttrs(unsigned n) {
  return SmallVector<utils::IteratorType>(n, utils::IteratorType::parallel);
}

static inline SmallVector<int64_t> getBroadcastDims(RankedTensorType src,
                                                    RankedTensorType dst) {
  SmallVector<int64_t> broadcastDims;
  auto srcShape = src.getShape();
  auto dstShape = dst.getShape();

  for (size_t i = 0; i < srcShape.size(); i++) {
    if (dstShape[i] != srcShape[i]) {
      assert(srcShape[i] == 1);
      broadcastDims.push_back(i);
    }
  }
  assert(!broadcastDims.empty() && "cannot identify broadcast dimension");
  return broadcastDims;
}

// Broadcasts input tensor based on TosaToLinalg's broadcastToShape
static inline AffineMap
getBroadcastAffineMap(MLIRContext *context, ArrayRef<int64_t> inputShape,
                      ArrayRef<int64_t> broadcastToShape) {

  assert(broadcastToShape.size() >= inputShape.size());

  // Create affine map and shapes for tensor initialization.
  SmallVector<AffineExpr> outExpr;

  size_t diff = broadcastToShape.size() - inputShape.size();
  for (size_t i = 0; i < broadcastToShape.size(); i++) {
    if (i < diff) {
      continue;
    }
    size_t j = i - diff;
    if (inputShape[j] == 1) {
      // Broadcast singleton dimension
      outExpr.push_back(mlir::getAffineConstantExpr(0, context));
      continue;
    }
    // Non-broadcast case
    outExpr.push_back(mlir::getAffineDimExpr(i, context));
  }
  return AffineMap::get(broadcastToShape.size(), 0, outExpr, context);
}

} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_TRITONARITHTOLINALG_CONVERSIONTOOLS_H
```
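
For reference, a hypothetical standalone illustration (not part of the commit) of what `getBroadcastAffineMap` computes for a concrete shape pair, assuming the header's transitive includes resolve in your build:

```cpp
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "triton-shared/Conversion/TritonArithToLinalg/ConversionTools.h"

int main() {
  mlir::MLIRContext context;
  // inputShape = [1, 4] broadcast to [8, 2, 4]: the leading size-8 dim is
  // newly added (diff = 1), the size-1 dim broadcasts, size 4 carries over.
  mlir::AffineMap map = mlir::triton::getBroadcastAffineMap(
      &context, /*inputShape=*/{1, 4}, /*broadcastToShape=*/{8, 2, 4});
  // Prints (d0, d1, d2) -> (0, d2): the unit dim always reads element 0,
  // and the new leading dim does not index the input at all.
  map.dump();
}
```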

include/triton-shared/Conversion/TritonToLinalgExperimental/CollapseShape.h

Lines changed: 22 additions & 0 deletions (new file)

```cpp
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_CONVERSION_TRITONTOLINALG_CollapseShape_H
#define TRITON_CONVERSION_TRITONTOLINALG_CollapseShape_H

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createCollapseShapePass();

} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_TRITONTOLINALG_CollapseShape_H
```
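
For orientation, scheduling the declared pass from C++ could look like the sketch below (standard MLIR `PassManager` API; the helper function name is made up, not from the commit):

```cpp
#include "mlir/Pass/PassManager.h"
#include "triton-shared/Conversion/TritonToLinalgExperimental/CollapseShape.h"

// Add the collapse-shape pass to a module-level pipeline, e.g. after the
// program has been lowered to linalg.
void addCollapseShape(mlir::PassManager &pm) {
  pm.addPass(mlir::triton::createCollapseShapePass());
}
```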

include/triton-shared/Conversion/TritonToLinalgExperimental/Passes.h

Lines changed: 1 addition & 0 deletions

```diff
@@ -11,6 +11,7 @@
 #include "triton-shared/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimental.h"
 #include "triton-shared/Conversion/TritonToLinalgExperimental/ReconcilePtrCasts.h"
 #include "triton-shared/Conversion/TritonToLinalgExperimental/TritonToPtr.h"
+#include "triton-shared/Conversion/TritonToLinalgExperimental/CollapseShape.h"
 
 namespace mlir {
 namespace triton {
```

include/triton-shared/Conversion/TritonToLinalgExperimental/Passes.td

Lines changed: 7 additions & 1 deletion

```diff
@@ -15,7 +15,9 @@ def TritonToLinalgExperimental : Pass<"triton-to-linalg-experimental", "mlir::Mo
   let constructor = "triton::createTritonToLinalgExperimentalPass()";
   let options = [
     Option<"enableMakeGatherScatterTensorPtr", "enable-make-gather-scatter", "bool", /*default*/"true",
-           "Enable make_gather_scatter_tptr support">
+           "Enable make_gather_scatter_tptr support">,
+    Option<"enableCollapseShape", "enable-collapse-shape", "bool", /*default*/"false",
+           "Enable collapse shape pass">,
   ];
 }
 
@@ -29,4 +31,8 @@ def TritonToPtr : Pass<"triton-to-ptr", "mlir::ModuleOp"> {
   let constructor = "triton::createTritonToPtrPass()";
 }
 
+def CollapseShape : Pass</*cli-arg*/"collapse-shape", /*Op*/"mlir::ModuleOp"> {
+  let summary = "Compress tensor dimensions to improve linalg{broadcast,transpose,fill,reduce} efficiency";
+  let constructor = "triton::createCollapseShapePass()";
+}
 #endif
```
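
With these Td entries the pass should be reachable from an opt-style driver. A hedged example: the flags come straight from the definitions above, while the `triton-shared-opt` binary name is an assumption about the repo's usual test driver:

```sh
# Run the standalone pass directly:
triton-shared-opt --collapse-shape input.mlir

# Or enable it inside the experimental lowering pipeline:
triton-shared-opt --triton-to-linalg-experimental="enable-collapse-shape=true" input.mlir
```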

lib/Conversion/TritonToLinalgExperimental/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -8,6 +8,7 @@ add_triton_library(TritonToLinalgExperimental
   TritonToLinalgExperimentalPass.cpp
   ReconcilePtrCastsPass.cpp
   TritonToPtrPass.cpp
+  CollapseShape.cpp
 
   DEPENDS
   TritonToLinalgExperimentalConversionPassIncGen
```
