@@ -764,7 +764,9 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
764764 if (vnniConf) {
765765 vecLoadType = getVnniVector (tileType.getShape (), tileType.getElementType (),
766766 *vnniConf);
767- packedAttr = mlir::UnitAttr::get (rewriter.getContext ());
767+ if (!transpose_bit) {
768+ packedAttr = mlir::UnitAttr::get (rewriter.getContext ());
769+ }
768770 }
769771 SmallVector<Value> loadVec;
770772 for (auto tile : loadTiles) {
@@ -1165,7 +1167,6 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
11651167 if (vnniFactor == -1 )
11661168 return failure ();
11671169
1168- VnniConfig vnniConfA{.vnniFactor = vnniFactor, .vnniAxis = 1 };
11691170 VnniConfig vnniConfB{.vnniFactor = vnniFactor, .vnniAxis = 0 };
11701171
11711172 // Load A sub-tiles.
@@ -1214,7 +1215,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
12141215 // Extract DPAS tiles from loaded sub-tiles.
12151216 TilesArray dpasVecA = extractVecSubTiles (rewriter, loc, loadVecA,
12161217 {dimM, kTile }, tileTypeA.getShape (),
1217- {dpasTileM, dpasTileK}, vnniConfA );
1218+ {dpasTileM, dpasTileK});
12181219 TilesArray dpasVecB = extractVecSubTiles (rewriter, loc, loadVecB,
12191220 {kTile , dimN}, tileTypeB.getShape (),
12201221 {dpasTileK, dpasTileN}, vnniConfB);
@@ -1629,7 +1630,7 @@ struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
16291630 using LinalgToXeGPUBase::LinalgToXeGPUBase;
16301631
16311632 void runOnOperation () override {
1632- LinalgToXeGPUOptions options{kTile , stages, dpasTile};
1633+ LinalgToXeGPUOptions options{kTile , stages, SmallVector< int64_t >{ dpasTile. begin (), dpasTile. end ()} };
16331634
16341635 // Run GEMM pattern first to allow fusion with its consumers.
16351636 RewritePatternSet gemmPatterns (&getContext ());
0 commit comments