1414
1515#include " mlir/Dialect/Arith/IR/Arith.h"
1616#include " mlir/Dialect/MemRef/IR/MemRef.h"
17+ #include " mlir/Dialect/Utils/StructuredOpsUtils.h"
1718#include " mlir/Dialect/Vector/IR/VectorOps.h"
1819#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1920#include " mlir/Pass/Pass.h"
@@ -312,28 +313,6 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
312313 }
313314};
314315
315- static LogicalResult validateDpasIndexing (PatternRewriter &rewriter,
316- vector::ContractionOp contractOp) {
317- MLIRContext *ctx = contractOp.getContext ();
318- SmallVector<AffineMap, 4 > maps = contractOp.getIndexingMapsArray ();
319-
320- // Operand rank defines expected data layout:
321- // - 2D for standard GEMM
322- // - 3D for VNNI layout
323- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
324- auto infer = [&](MapList m) { return AffineMap::inferFromExprList (m, ctx); };
325- AffineExpr m, n, k, vnni;
326- bindDims (ctx, m, n, k, vnni);
327-
328- if (contractOp.getRhsType ().getRank () == 2 ) {
329- // Require plain GEMM without any transposition.
330- return success (maps == infer ({{m, k}, {k, n}, {m, n}}));
331- }
332-
333- // Require VNNI layout.
334- return success (maps == infer ({{m, k, vnni}, {k, n, vnni}, {m, n}}));
335- }
336-
337316struct ContractionLowering : public OpRewritePattern <vector::ContractionOp> {
338317 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
339318
@@ -349,48 +328,30 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
349328 VectorType accType = dyn_cast<VectorType>(acc.getType ());
350329 if (!accType || accType.getRank () != 2 )
351330 return rewriter.notifyMatchFailure (contractOp, " Expects acc 2D vector" );
331+
332+ // Accept only plain 2D data layout.
333+ // VNNI packing is left to later lowering.
352334 TypedValue<VectorType> lhs = contractOp.getLhs ();
353- VectorType lhsType = lhs.getType ();
354- int64_t lhsRank = lhsType.getRank ();
355- if (!(lhsRank == 2 || lhsRank == 3 ))
356- return rewriter.notifyMatchFailure (contractOp,
357- " Expects lhs 2D or 3D vector" );
358335 TypedValue<VectorType> rhs = contractOp.getRhs ();
359- VectorType rhsType = rhs.getType ();
360- int64_t rhsRank = rhsType.getRank ();
361- if (!(rhsRank == 2 || rhsRank == 3 ))
336+ if (lhs.getType ().getRank () != 2 || rhs.getType ().getRank () != 2 )
362337 return rewriter.notifyMatchFailure (contractOp,
363- " Expects rhs 2D or 3D vector" );
364- if (lhsRank != rhsRank)
365- return rewriter.notifyMatchFailure (
366- contractOp, " Expects lhs and rhs to be the same rank" );
338+ " Expects lhs and rhs 2D vectors" );
367339
368- if (failed ( validateDpasIndexing (rewriter, contractOp)))
340+ if (! isRowMajorMatmul ( contractOp. getIndexingMapsAttr ( )))
369341 return rewriter.notifyMatchFailure (contractOp, " Invalid indexing maps" );
370342
371- // 3D shape implies VNNI layout verified by the earlier indexing check.
372- bool isVnni = rhsRank == 3 ;
373- auto rhsShape = rhsType.getShape ();
374- int64_t dimK = isVnni ? rhsShape[0 ] * rhsShape[2 ] : rhsShape[0 ];
375- unsigned elemBitWidth = rhsType.getElementType ().getIntOrFloatBitWidth ();
376- if (dimK != (8 * 32 / elemBitWidth))
343+ // TODO: Update shape validation to be target aware.
344+ auto rhsShape = rhs.getType ().getShape ();
345+ auto accShape = accType.getShape ();
346+ int64_t dimM = accShape[0 ];
347+ int64_t dimN = accShape[1 ];
348+ int64_t dimK = rhsShape[0 ];
349+ if (dimM != 8 || dimN != 16 || dimK % 8 != 0 )
377350 return rewriter.notifyMatchFailure (contractOp,
378- " Invalid K-dimension size" );
379- if (isVnni && rhsShape[2 ] != (32 / elemBitWidth))
380- return rewriter.notifyMatchFailure (contractOp, " Invalid VNNI factor" );
381-
382- if (isVnni) {
383- // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
384- // flat 2D shape for its lhs operand.
385- auto lhsShape = lhsType.getShape ();
386- auto lhsFlatType = VectorType::get (
387- {lhsShape[0 ], lhsShape[1 ] * lhsShape[2 ]}, lhsType.getElementType ());
388- lhs = rewriter.create <vector::ShapeCastOp>(loc, lhsFlatType, lhs)
389- .getResult ();
390- }
351+ " Invalid operand dimensions" );
391352
392353 auto dpasOp = rewriter.create <xegpu::DpasOp>(
393- loc, contractOp.getResultType (), lhs, rhs, acc);
354+ loc, TypeRange{ contractOp.getResultType ()}, ValueRange{ lhs, rhs, acc} );
394355 rewriter.replaceOp (contractOp, dpasOp);
395356
396357 return success ();
0 commit comments