Skip to content

Commit 610c8c6

Browse files
committed
[mlir][VectorToLLVM] Add support for unrolling and lowering multi-dimensional vector.scatter operations
1 parent 8a0d145 commit 610c8c6

File tree

7 files changed

+158
-17
lines changed

7 files changed

+158
-17
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
239239
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
240240
PatternBenefit benefit = 1);
241241

242+
/// Populate the pattern set with the following patterns:
243+
///
244+
/// [UnrollScatter]
245+
/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
246+
/// outermost dimension.
247+
void populateVectorScatterLoweringPatterns(RewritePatternSet &patterns,
248+
PatternBenefit benefit = 1);
249+
242250
/// Populate the pattern set with the following patterns:
243251
///
244252
/// [UnrollGather]

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ using UnrollVectorOpFn =
254254
function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
255255

256256
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
257-
UnrollVectorOpFn unrollFn);
257+
UnrollVectorOpFn unrollFn,
258+
VectorType vectorTy = nullptr);
258259

259260
/// Generic utility for unrolling values of type vector<NxAxBx...>
260261
/// to N values of type vector<AxBx...> using vector.extract. If the input

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9494
populateVectorStepLoweringPatterns(patterns);
9595
populateVectorRankReducingFMAPattern(patterns);
9696
populateVectorGatherLoweringPatterns(patterns);
97+
populateVectorScatterLoweringPatterns(patterns);
9798
populateVectorFromElementsUnrollPatterns(patterns);
9899
populateVectorToElementsUnrollPatterns(patterns);
99100
if (armI8MM) {

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
88
LowerVectorMask.cpp
99
LowerVectorMultiReduction.cpp
1010
LowerVectorScan.cpp
11+
LowerVectorScatter.cpp
1112
LowerVectorShapeCast.cpp
1213
LowerVectorShuffle.cpp
1314
LowerVectorStep.cpp
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//===- LowerVectorScatter.cpp - Lower 'vector.scatter' operation ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.scatter' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Arith/Utils/Utils.h"
16+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
#include "mlir/Dialect/SCF/IR/SCF.h"
18+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
19+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
21+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22+
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
23+
#include "mlir/IR/BuiltinTypes.h"
24+
#include "mlir/IR/Location.h"
25+
#include "mlir/IR/PatternMatch.h"
26+
#include "mlir/IR/TypeUtilities.h"
27+
28+
#define DEBUG_TYPE "vector-scatter-lowering"
29+
30+
using namespace mlir;
31+
using namespace mlir::vector;
32+
33+
namespace {
34+
35+
/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
36+
/// outermost dimension. For example:
37+
/// ```
38+
/// vector.scatter %base[%c0][%idx], %mask, %value :
39+
/// memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
40+
///
41+
/// ==>
42+
///
43+
/// %v0 = vector.extract %value[0] : vector<3xf32> from vector<2x3xf32>
44+
/// %m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
45+
/// %i0 = vector.extract %idx[0] : vector<3xi32> from vector<2x3xi32>
46+
/// vector.scatter %base[%c0][%i0], %m0, %v0 :
47+
/// memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
48+
///
49+
/// %v1 = vector.extract %value[1] : vector<3xf32> from vector<2x3xf32>
50+
/// %m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1>
51+
/// %i1 = vector.extract %idx[1] : vector<3xi32> from vector<2x3xi32>
52+
/// vector.scatter %base[%c0][%i1], %m1, %v1 :
53+
/// memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
54+
/// ```
55+
///
56+
/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
57+
///
58+
/// Supports vector types with a fixed leading dimension.
59+
struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
60+
using OpRewritePattern::OpRewritePattern;
61+
62+
LogicalResult matchAndRewrite(vector::ScatterOp op,
63+
PatternRewriter &rewriter) const override {
64+
Value indexVec = op.getIndices();
65+
Value maskVec = op.getMask();
66+
Value valueVec = op.getValueToStore();
67+
68+
// Get the vector type from one of the vector operands
69+
VectorType vectorTy = dyn_cast<VectorType>(indexVec.getType());
70+
if (!vectorTy)
71+
return failure();
72+
73+
auto unrollScatterFn = [&](PatternRewriter &rewriter, Location loc,
74+
VectorType subTy, int64_t index) {
75+
int64_t thisIdx[1] = {index};
76+
77+
Value indexSubVec =
78+
vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
79+
Value maskSubVec =
80+
vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
81+
Value valueSubVec =
82+
vector::ExtractOp::create(rewriter, loc, valueVec, thisIdx);
83+
84+
rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getOffsets(),
85+
indexSubVec, maskSubVec, valueSubVec,
86+
op.getAlignmentAttr());
87+
88+
// Return a dummy value since unrollVectorOp expects a Value
89+
return rewriter.create<ub::PoisonOp>(loc, subTy);
90+
};
91+
92+
return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy);
93+
}
94+
};
95+
96+
} // namespace
97+
98+
void mlir::vector::populateVectorScatterLoweringPatterns(
99+
RewritePatternSet &patterns, PatternBenefit benefit) {
100+
patterns.add<UnrollScatter>(patterns.getContext(), benefit);
101+
}

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -431,27 +431,51 @@ vector::unrollVectorValue(TypedValue<VectorType> vector,
431431
}
432432

433433
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
434-
vector::UnrollVectorOpFn unrollFn) {
435-
assert(op->getNumResults() == 1 && "expected single result");
436-
assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
437-
VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
438-
if (resultTy.getRank() < 2)
434+
vector::UnrollVectorOpFn unrollFn,
435+
VectorType vectorTy) {
436+
// If vector type is not provided, get it from the result
437+
if (!vectorTy) {
438+
if (op->getNumResults() != 1)
439+
return rewriter.notifyMatchFailure(
440+
op, "expected single result when vector type not provided");
441+
442+
vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
443+
if (!vectorTy)
444+
return rewriter.notifyMatchFailure(op, "expected vector type");
445+
}
446+
447+
if (vectorTy.getRank() < 2)
439448
return rewriter.notifyMatchFailure(op, "already 1-D");
440449

441450
// Unrolling doesn't take vscale into account. Pattern is disabled for
442451
// vectors with leading scalable dim(s).
443-
if (resultTy.getScalableDims().front())
452+
if (vectorTy.getScalableDims().front())
444453
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
445454

446455
Location loc = op->getLoc();
447-
Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
448-
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
449456

450-
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
457+
// Only create result value if the operation produces results
458+
Value result;
459+
if (op->getNumResults() > 0) {
460+
result = ub::PoisonOp::create(rewriter, loc, vectorTy);
461+
}
462+
463+
VectorType subTy = VectorType::Builder(vectorTy).dropDim(0);
464+
465+
for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
451466
Value subVector = unrollFn(rewriter, loc, subTy, i);
452-
result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
467+
468+
// Only insert if we have a result to build
469+
if (op->getNumResults() > 0) {
470+
result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
471+
}
472+
}
473+
474+
if (op->getNumResults() > 0) {
475+
rewriter.replaceOp(op, result);
476+
} else {
477+
rewriter.eraseOp(op);
453478
}
454479

455-
rewriter.replaceOp(op, result);
456480
return success();
457481
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,9 +1643,6 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
16431643
// vector.scatter
16441644
//===----------------------------------------------------------------------===//
16451645

1646-
// Multi-Dimensional scatters are not supported yet. Check that we do not lower
1647-
// them.
1648-
16491646
func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
16501647
%0 = arith.constant 0: index
16511648
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
@@ -1654,7 +1651,11 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
16541651
}
16551652

16561653
// CHECK-LABEL: func @scatter_with_mask
1657-
// CHECK: vector.scatter
1654+
// CHECK: llvm.extractvalue {{.*}}[0]
1655+
// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
1656+
// CHECK: llvm.extractvalue {{.*}}[1]
1657+
// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
1658+
// CHECK-NOT: vector.scatter
16581659

16591660
// -----
16601661

@@ -1669,7 +1670,11 @@ func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
16691670
}
16701671

16711672
// CHECK-LABEL: func @scatter_with_mask_scalable
1672-
// CHECK: vector.scatter
1673+
// CHECK: llvm.extractvalue {{.*}}[0]
1674+
// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
1675+
// CHECK: llvm.extractvalue {{.*}}[1]
1676+
// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
1677+
// CHECK-NOT: vector.scatter
16731678

16741679
// -----
16751680

0 commit comments

Comments
 (0)