Skip to content

Commit 7d1b9ca

Browse files
authored
[mlir][amx] Vector to AMX conversion pass (llvm#151121)
Adds a pass for Vector to AMX operation conversion. Initially, a direct rewrite for vector contraction in packed VNNI layout is supported. Operations are expected to already be in shapes which are AMX-compatible for the rewriting to occur.
1 parent 240c454 commit 7d1b9ca

File tree

7 files changed

+653
-0
lines changed

7 files changed

+653
-0
lines changed

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
7676
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
7777
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
78+
#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
7879
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
7980
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
8081
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,19 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
15311531
];
15321532
}
15331533

1534+
//===----------------------------------------------------------------------===//
1535+
// VectorToAMX
1536+
//===----------------------------------------------------------------------===//
1537+
1538+
def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
1539+
let summary = "Lower the operations from the vector dialect into the AMX "
1540+
"dialect";
1541+
let dependentDialects = [
1542+
"affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
1543+
"memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
1544+
];
1545+
}
1546+
15341547
//===----------------------------------------------------------------------===//
15351548
// XeVMToLLVM
15361549
//===----------------------------------------------------------------------===//
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
10+
#define MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
11+
12+
#include "mlir/IR/PatternMatch.h"
13+
14+
namespace mlir {
15+
class Pass;
16+
class RewritePatternSet;
17+
18+
#define GEN_PASS_DECL_CONVERTVECTORTOAMX
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
/// Collect a set of patterns to convert from the vector to AMX ops.
22+
void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
23+
24+
} // namespace mlir
25+
26+
#endif // MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ add_subdirectory(TosaToSCF)
6868
add_subdirectory(TosaToTensor)
6969
add_subdirectory(UBToLLVM)
7070
add_subdirectory(UBToSPIRV)
71+
add_subdirectory(VectorToAMX)
7172
add_subdirectory(VectorToArmSME)
7273
add_subdirectory(VectorToGPU)
7374
add_subdirectory(VectorToLLVM)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRVectorToAMX
2+
VectorToAMX.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRAMXDialect
12+
MLIRAffineUtils
13+
MLIRArithDialect
14+
MLIRLinalgUtils
15+
MLIRMemRefDialect
16+
MLIRSCFDialect
17+
MLIRTransforms
18+
MLIRVectorDialect
19+
)
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
10+
11+
#include "mlir/Dialect/AMX/AMXDialect.h"
12+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
13+
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
16+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
#include "mlir/Dialect/SCF/IR/SCF.h"
18+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
20+
#include "mlir/IR/Builders.h"
21+
#include "mlir/Pass/Pass.h"
22+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23+
24+
#include <numeric>
25+
26+
namespace mlir {
27+
#define GEN_PASS_DEF_CONVERTVECTORTOAMX
28+
#include "mlir/Conversion/Passes.h.inc"
29+
} // namespace mlir
30+
31+
using namespace mlir;
32+
33+
namespace {
34+
35+
/// Return true if vector shape is compatible with AMX tiles.
36+
/// The validation accounts for VNNI packing.
37+
static bool verifyAmxShape(VectorType vec) {
38+
// Check overall shape:
39+
// - 2D for plain layout input or output
40+
// - 3D for VNNI packed input
41+
if (vec.getRank() != 2 && vec.getRank() != 3)
42+
return false;
43+
44+
ArrayRef<int64_t> shape = vec.getShape();
45+
int64_t rows = shape[0];
46+
int64_t cols = shape[1];
47+
unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
48+
49+
// 3D shape indicates VNNI packed layout.
50+
if (vec.getRank() == 3) {
51+
int64_t vnniFactor = 32 / elemBitWidth;
52+
if (shape.back() != vnniFactor)
53+
return false;
54+
cols *= vnniFactor;
55+
}
56+
57+
// AMX tile supports up to 16 rows of 64 bytes each.
58+
constexpr unsigned maxRows = 16;
59+
constexpr unsigned maxBitsPerRow = 64 * 8;
60+
return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
61+
}
62+
63+
/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
64+
static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
65+
vector::ContractionOp contractOp) {
66+
VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
67+
if (!accType || accType.getRank() != 2)
68+
return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
69+
70+
// Expect 3D inputs for VNNI packed data.
71+
VectorType lhsType = contractOp.getLhs().getType();
72+
VectorType rhsType = contractOp.getRhs().getType();
73+
if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
74+
return rewriter.notifyMatchFailure(contractOp,
75+
"Expects lhs and rhs 3D vectors");
76+
77+
// Check if shapes are compatible with AMX tile.
78+
if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
79+
!verifyAmxShape(accType))
80+
return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
81+
82+
// Validate affine maps.
83+
//
84+
// Iterators can be ordered arbitrarily. Indexing map positions are based on
85+
// operands' target shapes.
86+
// The matrix layouts must match the following:
87+
// - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
88+
// - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
89+
// - matrix C - [M]x[N]
90+
SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
91+
AffineMap mapA = indexingMaps[0];
92+
AffineMap mapB = indexingMaps[1];
93+
if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
94+
mapB.getNumResults() != 3)
95+
return rewriter.notifyMatchFailure(contractOp,
96+
"Invalid input indexing maps");
97+
FailureOr<linalg::ContractionDimensions> dims =
98+
linalg::inferContractionDims(indexingMaps);
99+
if (failed(dims))
100+
return rewriter.notifyMatchFailure(contractOp,
101+
"Failed to infer contraction dims");
102+
// Two reduction dimensions are expected:
103+
// - one for the K dimension
104+
// - one for the VNNI factor
105+
if (dims->k.size() != 2)
106+
return rewriter.notifyMatchFailure(contractOp,
107+
"Expected two reduction dims");
108+
assert(dims->m.size() == 1 && dims->n.size() == 1 &&
109+
"Invalid parallel contraction dims");
110+
111+
SmallVector<vector::IteratorType> iteratorTypes =
112+
contractOp.getIteratorTypesArray();
113+
// Check VNNI dim maps - the innermost dim for A and B inputs.
114+
auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
115+
auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
116+
if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
117+
iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
118+
return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
119+
// Check K dim maps - non-transposed row-major layout.
120+
auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
121+
auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
122+
if (!redDimA || !redDimB || redDimA != redDimB ||
123+
iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
124+
return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
125+
// Check M and N dim maps - map to non-transposed output.
126+
AffineMap mapC = indexingMaps[2];
127+
auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
128+
auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
129+
if (!mDimC || !nDimC)
130+
return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
131+
auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
132+
if (!parallelDimA ||
133+
iteratorTypes[parallelDimA.getPosition()] !=
134+
vector::IteratorType::parallel ||
135+
parallelDimA != mDimC)
136+
return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
137+
auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
138+
if (!parallelDimB ||
139+
iteratorTypes[parallelDimB.getPosition()] !=
140+
vector::IteratorType::parallel ||
141+
parallelDimB != nDimC)
142+
return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
143+
144+
return success();
145+
}
146+
147+
/// Validate contraction operands for AMX lowering.
148+
static LogicalResult validateOperands(PatternRewriter &rewriter,
149+
vector::ContractionOp contractOp) {
150+
VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
151+
if (!accType)
152+
return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
153+
154+
// Check if operand types are compatible with AMX compute ops.
155+
bool validElemTypes = false;
156+
Type lhsElemType = contractOp.getLhs().getType().getElementType();
157+
Type rhsElemType = contractOp.getRhs().getType().getElementType();
158+
Type accElemType = accType.getElementType();
159+
if (accElemType.isInteger(32)) {
160+
validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
161+
} else if (accElemType.isF32()) {
162+
validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
163+
(lhsElemType.isBF16() && rhsElemType.isBF16());
164+
}
165+
if (!validElemTypes)
166+
return rewriter.notifyMatchFailure(contractOp,
167+
"Invalid combination of operand types");
168+
169+
if (failed(isAmxVnniLayout(rewriter, contractOp)))
170+
return failure();
171+
172+
return success();
173+
}
174+
175+
/// Collapses the two innermost dimensions together.
176+
static Value collapseLastDim(PatternRewriter &rewriter,
177+
TypedValue<MemRefType> memref) {
178+
int64_t rank = memref.getType().getRank();
179+
SmallVector<ReassociationIndices> reassocIndices;
180+
for (auto i : llvm::seq<int64_t>(0, rank - 2))
181+
reassocIndices.push_back({i});
182+
reassocIndices.push_back({rank - 2, rank - 1});
183+
return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
184+
reassocIndices);
185+
}
186+
187+
/// Loads vector values to an AMX tile.
188+
static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
189+
TypedValue<VectorType> vec) {
190+
Location loc = vec.getLoc();
191+
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
192+
193+
// Transfer the vector to a tile through an intermediate buffer.
194+
VectorType vecTy = vec.getType();
195+
Value buf = memref::AllocaOp::create(
196+
rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
197+
SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
198+
vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
199+
200+
// Collapse the VNNI dimension in case of packing.
201+
bool isPacked = vecTy.getRank() == 3;
202+
if (isPacked)
203+
buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
204+
205+
ArrayRef<int64_t> shape = vecTy.getShape();
206+
int64_t rows = shape[0];
207+
int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
208+
std::multiplies<int64_t>());
209+
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
210+
211+
return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
212+
{zeroIndex, zeroIndex});
213+
}
214+
215+
/// Stores an AMX tile in a vector.
216+
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
217+
TypedValue<amx::TileType> tile) {
218+
Location loc = tile.getLoc();
219+
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
220+
221+
// Transfer the tile to a vector through an intermediate buffer.
222+
amx::TileType tileTy = tile.getType();
223+
Value buf = memref::AllocaOp::create(
224+
rewriter, loc,
225+
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
226+
SmallVector<Value> indices(2, zeroIndex);
227+
amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
228+
229+
auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
230+
return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
231+
}
232+
233+
/// Rewrites a vector contraction with an add combining kind and validated,
/// AMX-compatible operands into an AMX tile multiplication. Operands are
/// loaded into tiles, multiplied, and the result tile is stored back into a
/// vector which replaces the contraction.
struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    const bool isAddKind =
        contractOp.getKind() == vector::CombiningKind::ADD;
    if (!isAddKind)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");
    if (failed(validateOperands(rewriter, contractOp)))
      return failure();

    // Move all three operands into tiles: lhs, rhs, then the accumulator.
    TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
    TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
    auto accVec = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
    assert(accVec && "Invalid accumulator type");
    TypedValue<amx::TileType> accTile = loadTile(rewriter, accVec);

    // Pick the float or integer tile multiplication based on the accumulator.
    amx::TileType resultTy = accTile.getType();
    const bool isFloatAcc = accVec.getType().getElementType().isFloat();
    TypedValue<amx::TileType> product;
    if (isFloatAcc) {
      product = amx::TileMulFOp::create(rewriter, loc, resultTy, lhsTile,
                                        rhsTile, accTile);
    } else {
      product = amx::TileMulIOp::create(rewriter, loc, resultTy, lhsTile,
                                        rhsTile, accTile);
    }

    Value result = storeTile(rewriter, product);
    rewriter.replaceOp(contractOp, result);

    return success();
  }
};
267+
268+
struct ConvertVectorToAMXPass
269+
: public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
270+
void runOnOperation() override {
271+
MLIRContext &ctx = getContext();
272+
RewritePatternSet patterns(&ctx);
273+
populateVectorToAMXConversionPatterns(patterns);
274+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
275+
return signalPassFailure();
276+
}
277+
};
278+
279+
} // namespace
280+
281+
/// Adds the vector contraction to AMX rewrite pattern to the given set.
void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ContractionToAMX>(ctx);
}

0 commit comments

Comments
 (0)