Skip to content

Commit a16211c

Browse files
authored
[mlir][amx] Direct AMX data transfers (#154114)
Extends the Vector to AMX conversion to attempt populating AMX tiles directly from memory. When possible, contraction producers and consumers are replaced by AMX tile data transfer operations. This shortens the data path by skipping intermediate register loads and stores.
1 parent 58f3b0d commit a16211c

File tree

2 files changed

+515
-14
lines changed

2 files changed

+515
-14
lines changed

mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp

Lines changed: 160 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "mlir/Dialect/AMX/AMXDialect.h"
1212
#include "mlir/Dialect/Affine/IR/AffineOps.h"
13-
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1413
#include "mlir/Dialect/Arith/IR/Arith.h"
1514
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
1615
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -21,6 +20,8 @@
2120
#include "mlir/Pass/Pass.h"
2221
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2322

23+
#include "llvm/Support/DebugLog.h"
24+
2425
#include <numeric>
2526

2627
namespace mlir {
@@ -30,6 +31,8 @@ namespace mlir {
3031

3132
using namespace mlir;
3233

34+
#define DEBUG_TYPE "vector-to-amx"
35+
3336
namespace {
3437

3538
/// Return true if vector shape is compatible with AMX tiles.
@@ -49,8 +52,10 @@ static bool verifyAmxShape(VectorType vec) {
4952
// 3D shape indicates VNNI packed layout.
5053
if (vec.getRank() == 3) {
5154
int64_t vnniFactor = 32 / elemBitWidth;
52-
if (shape.back() != vnniFactor)
55+
if (shape.back() != vnniFactor) {
56+
LDBG() << "invalid VNNI packing factor";
5357
return false;
58+
}
5459
cols *= vnniFactor;
5560
}
5661

@@ -60,7 +65,7 @@ static bool verifyAmxShape(VectorType vec) {
6065
return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
6166
}
6267

63-
/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
68+
/// Check if contraction operands are in AMX-compatible packed VNNI layout.
6469
static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
6570
vector::ContractionOp contractOp) {
6671
VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
@@ -172,9 +177,9 @@ static LogicalResult validateOperands(PatternRewriter &rewriter,
172177
return success();
173178
}
174179

175-
/// Collapses the two innermost dimensions together.
176-
static Value collapseLastDim(PatternRewriter &rewriter,
177-
TypedValue<MemRefType> memref) {
180+
/// Collapse the two innermost dimensions together.
181+
static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
182+
TypedValue<MemRefType> memref) {
178183
int64_t rank = memref.getType().getRank();
179184
SmallVector<ReassociationIndices> reassocIndices;
180185
for (auto i : llvm::seq<int64_t>(0, rank - 2))
@@ -184,21 +189,148 @@ static Value collapseLastDim(PatternRewriter &rewriter,
184189
reassocIndices);
185190
}
186191

187-
/// Loads vector values to an AMX tile.
192+
/// Attempt to create an AMX tile load/store operation equivalent to the given
193+
/// vector transfer `xfer` op.
194+
/// This approach allows to skip longer route through registers and a temporary
195+
/// buffer otherwise required to move data to/from an AMX tile.
196+
static Operation *
197+
loadStoreFromTransfer(PatternRewriter &rewriter,
198+
VectorTransferOpInterface xferOp, bool isPacked,
199+
TypedValue<amx::TileType> tileToStore = nullptr) {
200+
if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
201+
return nullptr;
202+
if (xferOp.hasOutOfBoundsDim() ||
203+
!xferOp.getPermutationMap().isMinorIdentity())
204+
return nullptr;
205+
206+
// Extra checks in case of a write op.
207+
// Stores must not be packed.
208+
if (isa<vector::TransferWriteOp>(xferOp) &&
209+
(!tileToStore || isPacked ||
210+
tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
211+
return nullptr;
212+
213+
// Check for a memref source buffer.
214+
// AMX data transfer requires at least 2D shape to correctly
215+
// infer stride between rows.
216+
Value base = xferOp.getBase();
217+
auto memTy = dyn_cast<MemRefType>(base.getType());
218+
int64_t memRank = memTy.getRank();
219+
if (!memTy || memRank < 2)
220+
return nullptr;
221+
222+
// Check that the source buffer has enough contiguous elements to load whole
223+
// AMX tile row.
224+
//
225+
// To ensure correctness, the validation is conservative and expects the
226+
// buffer's innermost dimensions to be statically known, equal to or larger
227+
// than the vector row length, and equal to the VNNI dimension if applicable.
228+
//
229+
// This check could be relaxed to accept more arbitrarily shaped buffers as
230+
// long as there are enough contiguous elements to load a whole row.
231+
if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
232+
return nullptr;
233+
VectorType vecTy = xferOp.getVectorType();
234+
ArrayRef<int64_t> vecShape = vecTy.getShape();
235+
ArrayRef<int64_t> memShape = memTy.getShape();
236+
if (memShape.back() == ShapedType::kDynamic ||
237+
memShape.back() < vecShape.back())
238+
return nullptr;
239+
if (isPacked &&
240+
(memShape.back() != vecShape.back() ||
241+
memShape[memShape.size() - 2] == ShapedType::kDynamic ||
242+
memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
243+
return nullptr;
244+
245+
// Load values directly from the buffer to an AMX tile.
246+
PatternRewriter::InsertionGuard g(rewriter);
247+
rewriter.setInsertionPoint(xferOp);
248+
Location loc = xferOp.getLoc();
249+
250+
// Create a subview of the source buffer based on the transfer op to resolve
251+
// offsets.
252+
SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
253+
int64_t vecRank = vecTy.getRank();
254+
assert(memRank >= vecRank &&
255+
"Expects buffer to be the same or greater rank than vector");
256+
SmallVector<int64_t> shape(memRank - vecRank, 1);
257+
shape.append(vecShape.begin(), vecShape.end());
258+
TypedValue<MemRefType> src =
259+
memref::SubViewOp::create(
260+
rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
261+
getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
262+
.getResult();
263+
264+
// Collapse the VNNI dimension in case of packing.
265+
if (isPacked)
266+
src = collapseLastDim(rewriter, src);
267+
int64_t rows = vecShape[0];
268+
int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
269+
std::multiplies<int64_t>());
270+
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
271+
272+
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
273+
SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
274+
275+
Operation *amxTileOp = nullptr;
276+
if (isa<vector::TransferReadOp>(xferOp)) {
277+
amxTileOp =
278+
amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
279+
} else if (isa<vector::TransferWriteOp>(xferOp)) {
280+
amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
281+
tileToStore);
282+
} else {
283+
llvm_unreachable("unsupported vector transfer op");
284+
}
285+
286+
return amxTileOp;
287+
}
288+
289+
/// Attempt to create an AMX tile load operation equivalent to the given
290+
/// vector transfer `readOp`.
291+
/// Returns loaded AMX tile if successful.
292+
static FailureOr<TypedValue<amx::TileType>>
293+
loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
294+
bool isPacked) {
295+
amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
296+
loadStoreFromTransfer(rewriter, readOp, isPacked));
297+
if (!loadOp)
298+
return failure();
299+
return loadOp.getRes();
300+
}
301+
302+
/// Attempt to create an AMX tile store operation equivalent to the given
303+
/// vector transfer `writeOp`.
304+
static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
305+
vector::TransferWriteOp writeOp,
306+
TypedValue<amx::TileType> tileToStore) {
307+
return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
308+
tileToStore));
309+
}
310+
311+
/// Load vector values to an AMX tile.
188312
static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
189313
TypedValue<VectorType> vec) {
190314
Location loc = vec.getLoc();
191-
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
192315

193-
// Transfer the vector to a tile through an intermediate buffer.
194316
VectorType vecTy = vec.getType();
317+
bool isPacked = vecTy.getRank() == 3;
318+
319+
// Try to load tile directly from vector producer's buffer.
320+
auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
321+
FailureOr<TypedValue<amx::TileType>> tile =
322+
loadFromTransfer(rewriter, readOp, isPacked);
323+
if (succeeded(tile))
324+
return *tile;
325+
326+
// Transfer the vector to a tile through an intermediate buffer.
195327
Value buf = memref::AllocaOp::create(
196328
rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
329+
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
197330
SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
198331
vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
199332

200333
// Collapse the VNNI dimension in case of packing.
201-
bool isPacked = vecTy.getRank() == 3;
202334
if (isPacked)
203335
buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
204336

@@ -212,17 +344,17 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
212344
{zeroIndex, zeroIndex});
213345
}
214346

215-
/// Stores an AMX tile in a vector.
347+
/// Store an AMX tile in a vector.
216348
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
217349
TypedValue<amx::TileType> tile) {
218350
Location loc = tile.getLoc();
219-
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
220351

221352
// Transfer the tile to a vector through an intermediate buffer.
222353
amx::TileType tileTy = tile.getType();
223354
Value buf = memref::AllocaOp::create(
224355
rewriter, loc,
225356
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
357+
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
226358
SmallVector<Value> indices(2, zeroIndex);
227359
amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
228360

@@ -258,8 +390,22 @@ struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
258390
lhsTile, rhsTile, accTile);
259391
}
260392

261-
Value res = storeTile(rewriter, tileMul);
262-
rewriter.replaceOp(contractOp, res);
393+
// If the contraction result is only written back to memory, try to replace
394+
// the vector op with an AMX store directly.
395+
Value res = contractOp.getResult();
396+
if (res.hasOneUse()) {
397+
auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
398+
LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
399+
if (succeeded(storeRes)) {
400+
rewriter.eraseOp(writeOp);
401+
rewriter.eraseOp(contractOp);
402+
return success();
403+
}
404+
}
405+
406+
// Load the result back into a vector.
407+
Value newResult = storeTile(rewriter, tileMul);
408+
rewriter.replaceOp(contractOp, newResult);
263409

264410
return success();
265411
}

0 commit comments

Comments
 (0)