Skip to content

Commit 1d2db1a

Browse files
authored
[LinalgExt] Add simple vectorization for map_scatter (#21090)
Allows vector types for the map_scatter input operand, and adds a trivial vectorization pass, which converts the input of map_scatter ops to vector operands with a vector.transfer_read. This vectorization lowering is not complete, because the map_scatter op in this form will not be able to lower into other operations. The final lowering will happen as a decomposition after the map_scatter op has been both vectorized and bufferized, and the map_scatter op will be decomposed into a vector.scatter op. The reason this decomposition will need to happen after bufferization, is that the vector.scatter op requires a memref base. The decomposition pattern of the vectorized map_scatter will come as a follow-up PR. Eventually, the vectorized map_scatter op could instead be decomposed into an `iree_vector_ext.transfer_scatter` op, but since the op does not exist yet, the current plan is to use `vector.scatter` first. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent a6bbf4a commit 1d2db1a

File tree

11 files changed

+160
-1
lines changed

11 files changed

+160
-1
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,9 @@ LogicalResult MapScatterOp::verify() {
468468
if (getInputType().getElementType() != getOutputType().getElementType()) {
469469
return emitOpError("expected input and output element types to match");
470470
}
471+
if (getInputType().getRank() == 0) {
472+
return emitOpError("expected input type to have non-zero rank");
473+
}
471474
Region &transformRegion = getTransformationRegion();
472475
Block &transformBody = transformRegion.getBlocks().front();
473476
if (transformBody.getNumArguments() != getInputRank()) {

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def IREELinalgExt_MapScatterOp : IREELinalgExt_PureOp<"map_scatter",
348348
the mask value is true, the input value will be written.
349349
}];
350350
let arguments = (ins
351-
AnyRankedTensorOrMemRefType:$input,
351+
AnyShaped:$input,
352352
AnyRankedTensorOrMemRefType:$output
353353
);
354354
let results = (outs Variadic<AnyRankedTensor>:$results);
@@ -374,6 +374,11 @@ def IREELinalgExt_MapScatterOp : IREELinalgExt_PureOp<"map_scatter",
374374
return getOutputType().getRank();
375375
}
376376

377+
// Return true if the map_scatter op has vector semantics.
378+
bool isVectorized() {
379+
return isa<VectorType>(getInputType());
380+
}
381+
377382
// Helper to apply transformations to the source index block arguments of
378383
// the transformation body, and replace the uses of the previous source
379384
// indices with the values returned by `transformationBuilder`. The argument

compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,21 @@ func.func @map_scatter_wrong_num_yielded_values(
568568

569569
// -----
570570

571+
func.func @map_scatter_0D(
572+
%input: vector<f32>, %output: memref<4xf32>
573+
) {
574+
// expected-error@+1 {{expected input type to have non-zero rank}}
575+
iree_linalg_ext.map_scatter %input into %output {
576+
^bb0():
577+
%mask = arith.constant true
578+
%zero = arith.constant 0 : index
579+
iree_linalg_ext.yield %zero, %mask : index, i1
580+
} : vector<f32> into memref<4xf32>
581+
return
582+
}
583+
584+
// -----
585+
571586
func.func @arg_compare_invalid_too_many_inputs(
572587
%input_val: tensor<2x10xf32>,
573588
%input_extra: tensor<2x10xf32>,

compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,28 @@ func.func @map_scatter_memref_static(
833833

834834
// -----
835835

836+
func.func @map_scatter_vector(
837+
%input: vector<4x16x64xf32>, %output: tensor<4x16x64xf32>
838+
) -> tensor<4x16x64xf32> {
839+
%0 = iree_linalg_ext.map_scatter %input into %output {
840+
^bb0(%idx0: index, %idx1: index, %idx2: index):
841+
%mask = arith.constant true
842+
iree_linalg_ext.yield %idx0, %idx1, %idx2, %mask : index, index, index, i1
843+
} : vector<4x16x64xf32> into tensor<4x16x64xf32> -> tensor<4x16x64xf32>
844+
return %0 : tensor<4x16x64xf32>
845+
}
846+
// CHECK-LABEL: func.func @map_scatter_vector(
847+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
848+
// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
849+
// CHECK: %[[RES:.+]] = iree_linalg_ext.map_scatter %[[INPUT]] into %[[OUTPUT]] {
850+
// CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
851+
// CHECK: %[[MASK:.+]] = arith.constant true
852+
// CHECK: iree_linalg_ext.yield %[[IDX0]], %[[IDX1]], %[[IDX2]], %[[MASK]]
853+
// CHECK: } : vector<4x16x64xf32> into tensor<4x16x64xf32> -> tensor<4x16x64xf32>
854+
// CHECK: return %[[RES]] : tensor<4x16x64xf32>
855+
856+
// -----
857+
836858
func.func @arg_compare_static(
837859
%input : tensor<2x6xf32>,
838860
%outv : tensor<2xf32>, %outi : tensor<2xindex>

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ iree_compiler_cc_library(
4545
"TestReshapeFusion.cpp",
4646
"TileAttention.cpp",
4747
"TransposeFusion.cpp",
48+
"VectorizeIREELinalgExtOps.cpp",
4849
],
4950
hdrs = [
5051
"Passes.h",
@@ -85,5 +86,6 @@ iree_compiler_cc_library(
8586
"@llvm-project//mlir:TransformDialect",
8687
"@llvm-project//mlir:TransformUtils",
8788
"@llvm-project//mlir:Transforms",
89+
"@llvm-project//mlir:VectorDialect",
8890
],
8991
)

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ iree_cc_library(
4242
"TestReshapeFusion.cpp"
4343
"TileAttention.cpp"
4444
"TransposeFusion.cpp"
45+
"VectorizeIREELinalgExtOps.cpp"
4546
DEPS
4647
::PassesIncGen
4748
LLVMSupport
@@ -72,6 +73,7 @@ iree_cc_library(
7273
MLIRTransformDialect
7374
MLIRTransformUtils
7475
MLIRTransforms
76+
MLIRVectorDialect
7577
iree::compiler::Dialect::LinalgExt::IR
7678
iree::compiler::Dialect::LinalgExt::Utils
7779
iree::compiler::Dialect::Util::IR

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,13 @@ def TestReshapeFusionPass :
108108
let summary = "Test reshape fusion patterns";
109109
}
110110

111+
def VectorizeIREELinalgExtOpsPass :
112+
InterfacePass<"iree-linalg-ext-vectorize-ops", "mlir::FunctionOpInterface"> {
113+
let summary = "Convert linalg_ext ops into their vector form.";
114+
let dependentDialects = [
115+
"::mlir::vector::VectorDialect",
116+
"::mlir::arith::ArithDialect"
117+
];
118+
}
119+
111120
#endif // IREE_DIALECT_LINALGEXT_PASSES
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
8+
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
9+
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
namespace mlir::iree_compiler::IREE::LinalgExt {
15+
16+
#define GEN_PASS_DEF_VECTORIZEIREELINALGEXTOPSPASS
17+
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
18+
19+
namespace {
20+
21+
struct VectorizeStaticMapScatterOpPattern final
22+
: OpRewritePattern<IREE::LinalgExt::MapScatterOp> {
23+
using OpRewritePattern<IREE::LinalgExt::MapScatterOp>::OpRewritePattern;
24+
LogicalResult matchAndRewrite(IREE::LinalgExt::MapScatterOp mapScatterOp,
25+
PatternRewriter &rewriter) const override {
26+
if (mapScatterOp.isVectorized()) {
27+
return rewriter.notifyMatchFailure(mapScatterOp,
28+
"map_scatter is already vectorized");
29+
}
30+
ShapedType inputType = mapScatterOp.getInputType();
31+
if (!inputType.hasStaticShape()) {
32+
return rewriter.notifyMatchFailure(mapScatterOp,
33+
"map_scatter has non-static shape");
34+
}
35+
Location loc = mapScatterOp.getLoc();
36+
rewriter.setInsertionPoint(mapScatterOp);
37+
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
38+
SmallVector<Value> zeros(inputType.getRank(), zero);
39+
auto inputVectorType =
40+
VectorType::get(inputType.getShape(), inputType.getElementType());
41+
Value inputVector = rewriter.create<vector::TransferReadOp>(
42+
loc, inputVectorType, mapScatterOp.getInput(), /*indices=*/zeros);
43+
auto vectorizedMapScatterOp =
44+
clone(rewriter, mapScatterOp, mapScatterOp.getResultTypes(),
45+
{inputVector, mapScatterOp.getOutput()});
46+
rewriter.replaceOp(mapScatterOp, vectorizedMapScatterOp);
47+
return success();
48+
}
49+
};
50+
51+
struct VectorizeIREELinalgExtOpsPass final
52+
: impl::VectorizeIREELinalgExtOpsPassBase<VectorizeIREELinalgExtOpsPass> {
53+
void runOnOperation() {
54+
MLIRContext *context = &getContext();
55+
RewritePatternSet patterns(context);
56+
patterns.add<VectorizeStaticMapScatterOpPattern>(context);
57+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
58+
return signalPassFailure();
59+
}
60+
}
61+
};
62+
} // namespace
63+
64+
} // namespace mlir::iree_compiler::IREE::LinalgExt

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ iree_lit_test_suite(
2828
"reshape_fusion.mlir",
2929
"split_reduction.mlir",
3030
"tiling.mlir",
31+
"vectorize_iree_linalg_ext_ops.mlir",
3132
],
3233
include = ["*.mlir"],
3334
),

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ iree_lit_test_suite(
2626
"reshape_fusion.mlir"
2727
"split_reduction.mlir"
2828
"tiling.mlir"
29+
"vectorize_iree_linalg_ext_ops.mlir"
2930
TOOLS
3031
FileCheck
3132
iree-opt

0 commit comments

Comments
 (0)