1010
1111#include " mlir/Dialect/AMX/AMXDialect.h"
1212#include " mlir/Dialect/Affine/IR/AffineOps.h"
13- #include " mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1413#include " mlir/Dialect/Arith/IR/Arith.h"
1514#include " mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
1615#include " mlir/Dialect/MemRef/IR/MemRef.h"
2120#include " mlir/Pass/Pass.h"
2221#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2322
23+ #include " llvm/Support/DebugLog.h"
24+
2425#include < numeric>
2526
2627namespace mlir {
@@ -30,6 +31,8 @@ namespace mlir {
3031
3132using namespace mlir ;
3233
34+ #define DEBUG_TYPE " vector-to-amx"
35+
3336namespace {
3437
3538// / Return true if vector shape is compatible with AMX tiles.
@@ -49,8 +52,10 @@ static bool verifyAmxShape(VectorType vec) {
4952 // 3D shape indicates VNNI packed layout.
5053 if (vec.getRank () == 3 ) {
5154 int64_t vnniFactor = 32 / elemBitWidth;
52- if (shape.back () != vnniFactor)
55+ if (shape.back () != vnniFactor) {
56+ LDBG () << " invalid VNNI packing factor" ;
5357 return false ;
58+ }
5459 cols *= vnniFactor;
5560 }
5661
@@ -60,7 +65,7 @@ static bool verifyAmxShape(VectorType vec) {
6065 return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
6166}
6267
63- // / Checks if contraction operands are in AMX-compatible packed VNNI layout.
68+ // / Check if contraction operands are in AMX-compatible packed VNNI layout.
6469static LogicalResult isAmxVnniLayout (PatternRewriter &rewriter,
6570 vector::ContractionOp contractOp) {
6671 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc ().getType ());
@@ -172,9 +177,9 @@ static LogicalResult validateOperands(PatternRewriter &rewriter,
172177 return success ();
173178}
174179
175- // / Collapses the two innermost dimensions together.
176- static Value collapseLastDim (PatternRewriter &rewriter,
177- TypedValue<MemRefType> memref) {
180+ // / Collapse the two innermost dimensions together.
181+ static TypedValue<MemRefType> collapseLastDim (PatternRewriter &rewriter,
182+ TypedValue<MemRefType> memref) {
178183 int64_t rank = memref.getType ().getRank ();
179184 SmallVector<ReassociationIndices> reassocIndices;
180185 for (auto i : llvm::seq<int64_t >(0 , rank - 2 ))
@@ -184,21 +189,148 @@ static Value collapseLastDim(PatternRewriter &rewriter,
184189 reassocIndices);
185190}
186191
187- // / Loads vector values to an AMX tile.
192+ // / Attempt to create an AMX tile load/store operation equivalent to the given
193+ // / vector transfer `xfer` op.
194+ // / This approach allows to skip longer route through registers and a temporary
195+ // / buffer otherwise required to move data to/from an AMX tile.
196+ static Operation *
197+ loadStoreFromTransfer (PatternRewriter &rewriter,
198+ VectorTransferOpInterface xferOp, bool isPacked,
199+ TypedValue<amx::TileType> tileToStore = nullptr ) {
200+ if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
201+ return nullptr ;
202+ if (xferOp.hasOutOfBoundsDim () ||
203+ !xferOp.getPermutationMap ().isMinorIdentity ())
204+ return nullptr ;
205+
206+ // Extra checks in case of a write op.
207+ // Stores must not be packed.
208+ if (isa<vector::TransferWriteOp>(xferOp) &&
209+ (!tileToStore || isPacked ||
210+ tileToStore.getType ().getShape () != xferOp.getVectorType ().getShape ()))
211+ return nullptr ;
212+
213+ // Check for a memref source buffer.
214+ // AMX data transfer requires at least 2D shape to correctly
215+ // infer stride between rows.
216+ Value base = xferOp.getBase ();
217+ auto memTy = dyn_cast<MemRefType>(base.getType ());
218+ int64_t memRank = memTy.getRank ();
219+ if (!memTy || memRank < 2 )
220+ return nullptr ;
221+
222+ // Check that the source buffer has enough contiguous elements to load whole
223+ // AMX tile row.
224+ //
225+ // To ensure correctness, the validation is conservative and expects the
226+ // buffer's innermost dimensions to be statically known, equal to or larger
227+ // than the vector row length, and equal to the VNNI dimension if applicable.
228+ //
229+ // This check could be relaxed to accept more arbitrarily shaped buffers as
230+ // long as there are enough contiguous elements to load a whole row.
231+ if (!memTy.areTrailingDimsContiguous (isPacked ? 2 : 1 ))
232+ return nullptr ;
233+ VectorType vecTy = xferOp.getVectorType ();
234+ ArrayRef<int64_t > vecShape = vecTy.getShape ();
235+ ArrayRef<int64_t > memShape = memTy.getShape ();
236+ if (memShape.back () == ShapedType::kDynamic ||
237+ memShape.back () < vecShape.back ())
238+ return nullptr ;
239+ if (isPacked &&
240+ (memShape.back () != vecShape.back () ||
241+ memShape[memShape.size () - 2 ] == ShapedType::kDynamic ||
242+ memShape[memShape.size () - 2 ] < vecShape[vecShape.size () - 2 ]))
243+ return nullptr ;
244+
245+ // Load values directly from the buffer to an AMX tile.
246+ PatternRewriter::InsertionGuard g (rewriter);
247+ rewriter.setInsertionPoint (xferOp);
248+ Location loc = xferOp.getLoc ();
249+
250+ // Create a subview of the source buffer based on the transfer op to resolve
251+ // offsets.
252+ SmallVector<OpFoldResult> strides (memRank, rewriter.getIndexAttr (1 ));
253+ int64_t vecRank = vecTy.getRank ();
254+ assert (memRank >= vecRank &&
255+ " Expects buffer to be the same or greater rank than vector" );
256+ SmallVector<int64_t > shape (memRank - vecRank, 1 );
257+ shape.append (vecShape.begin (), vecShape.end ());
258+ TypedValue<MemRefType> src =
259+ memref::SubViewOp::create (
260+ rewriter, loc, base, getAsOpFoldResult (xferOp.getIndices ()),
261+ getAsOpFoldResult (rewriter.getI64ArrayAttr (shape)), strides)
262+ .getResult ();
263+
264+ // Collapse the VNNI dimension in case of packing.
265+ if (isPacked)
266+ src = collapseLastDim (rewriter, src);
267+ int64_t rows = vecShape[0 ];
268+ int64_t cols = std::accumulate (vecShape.begin () + 1 , vecShape.end (), 1 ,
269+ std::multiplies<int64_t >());
270+ auto tileType = amx::TileType::get ({rows, cols}, vecTy.getElementType ());
271+
272+ Value zeroIndex = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
273+ SmallVector<Value> tileIndicides (src.getType ().getRank (), zeroIndex);
274+
275+ Operation *amxTileOp = nullptr ;
276+ if (isa<vector::TransferReadOp>(xferOp)) {
277+ amxTileOp =
278+ amx::TileLoadOp::create (rewriter, loc, tileType, src, tileIndicides);
279+ } else if (isa<vector::TransferWriteOp>(xferOp)) {
280+ amxTileOp = amx::TileStoreOp::create (rewriter, loc, src, tileIndicides,
281+ tileToStore);
282+ } else {
283+ llvm_unreachable (" unsupported vector transfer op" );
284+ }
285+
286+ return amxTileOp;
287+ }
288+
289+ // / Attempt to create an AMX tile load operation equivalent to the given
290+ // / vector transfer `readOp`.
291+ // / Returns loaded AMX tile if successful.
292+ static FailureOr<TypedValue<amx::TileType>>
293+ loadFromTransfer (PatternRewriter &rewriter, vector::TransferReadOp readOp,
294+ bool isPacked) {
295+ amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
296+ loadStoreFromTransfer (rewriter, readOp, isPacked));
297+ if (!loadOp)
298+ return failure ();
299+ return loadOp.getRes ();
300+ }
301+
302+ // / Attempt to create an AMX tile store operation equivalent to the given
303+ // / vector transfer `writeOp`.
304+ static LogicalResult storeFromTransfer (PatternRewriter &rewriter,
305+ vector::TransferWriteOp writeOp,
306+ TypedValue<amx::TileType> tileToStore) {
307+ return success (loadStoreFromTransfer (rewriter, writeOp, /* isPacked=*/ false ,
308+ tileToStore));
309+ }
310+
311+ // / Load vector values to an AMX tile.
188312static TypedValue<amx::TileType> loadTile (PatternRewriter &rewriter,
189313 TypedValue<VectorType> vec) {
190314 Location loc = vec.getLoc ();
191- Value zeroIndex = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
192315
193- // Transfer the vector to a tile through an intermediate buffer.
194316 VectorType vecTy = vec.getType ();
317+ bool isPacked = vecTy.getRank () == 3 ;
318+
319+ // Try to load tile directly from vector producer's buffer.
320+ auto readOp = vec.getDefiningOp <vector::TransferReadOp>();
321+ FailureOr<TypedValue<amx::TileType>> tile =
322+ loadFromTransfer (rewriter, readOp, isPacked);
323+ if (succeeded (tile))
324+ return *tile;
325+
326+ // Transfer the vector to a tile through an intermediate buffer.
195327 Value buf = memref::AllocaOp::create (
196328 rewriter, loc, MemRefType::get (vecTy.getShape (), vecTy.getElementType ()));
329+ Value zeroIndex = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
197330 SmallVector<Value> indices (vecTy.getRank (), zeroIndex);
198331 vector::TransferWriteOp::create (rewriter, loc, vec, buf, indices);
199332
200333 // Collapse the VNNI dimension in case of packing.
201- bool isPacked = vecTy.getRank () == 3 ;
202334 if (isPacked)
203335 buf = collapseLastDim (rewriter, cast<TypedValue<MemRefType>>(buf));
204336
@@ -212,17 +344,17 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
212344 {zeroIndex, zeroIndex});
213345}
214346
215- // / Stores an AMX tile in a vector.
347+ // / Store an AMX tile in a vector.
216348static TypedValue<VectorType> storeTile (PatternRewriter &rewriter,
217349 TypedValue<amx::TileType> tile) {
218350 Location loc = tile.getLoc ();
219- Value zeroIndex = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
220351
221352 // Transfer the tile to a vector through an intermediate buffer.
222353 amx::TileType tileTy = tile.getType ();
223354 Value buf = memref::AllocaOp::create (
224355 rewriter, loc,
225356 MemRefType::get (tileTy.getShape (), tileTy.getElementType ()));
357+ Value zeroIndex = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
226358 SmallVector<Value> indices (2 , zeroIndex);
227359 amx::TileStoreOp::create (rewriter, loc, buf, indices, tile);
228360
@@ -258,8 +390,22 @@ struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
258390 lhsTile, rhsTile, accTile);
259391 }
260392
261- Value res = storeTile (rewriter, tileMul);
262- rewriter.replaceOp (contractOp, res);
393+ // If the contraction result is only written back to memory, try to replace
394+ // the vector op with an AMX store directly.
395+ Value res = contractOp.getResult ();
396+ if (res.hasOneUse ()) {
397+ auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers ().begin ());
398+ LogicalResult storeRes = storeFromTransfer (rewriter, writeOp, tileMul);
399+ if (succeeded (storeRes)) {
400+ rewriter.eraseOp (writeOp);
401+ rewriter.eraseOp (contractOp);
402+ return success ();
403+ }
404+ }
405+
406+ // Load the result back into a vector.
407+ Value newResult = storeTile (rewriter, tileMul);
408+ rewriter.replaceOp (contractOp, newResult);
263409
264410 return success ();
265411 }
0 commit comments