
Commit 6f12f2f

Remove vnni and transpose logic from XeTileToXeGPU (#842)
1 parent ba6721e commit 6f12f2f
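
This commit removes the VNNI-packing and transpose special cases from the XeTile-to-XeGPU lowering: loads now always produce plain 2-D row-major vectors, and col-major tiles are expected to be canonicalized to row-major before this pass runs. The visible IR change is in the dpas B operand, as the gemm_preop.mlir update below shows; here is a minimal before/after sketch (SSA names are invented for illustration):

```mlir
// Before: B arrived VNNI-packed, with pairs of f16 rows folded into the
// innermost dim so each 32-bit lane carries two elements.
%c_old = xegpu.dpas %a, %b_vnni, %acc
    : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
// After: B stays a plain 2-D vector.
%c_new = xegpu.dpas %a, %b, %acc
    : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
```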

File tree

9 files changed: +286 -403 lines changed


lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 21 additions & 134 deletions
```diff
@@ -167,8 +167,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
                   mlir::ValueRange inputs, mlir::DenseI64ArrayAttr inBlkSizes,
                   mlir::DenseI64ArrayAttr outBlkSizes,
                   llvm::ArrayRef<int64_t> inGrids,
-                  llvm::ArrayRef<int64_t> outGrids, bool isVnniFormat = false,
-                  bool isForDPASB = false) {
+                  llvm::ArrayRef<int64_t> outGrids) {
 
   // handle based on the dim0, and save results into intermediates
   llvm::SmallVector<mlir::Value> intermediates(outGrids[0] * inGrids[1]);
```
```diff
@@ -190,18 +189,11 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
         }
       }
     }
-  } else { // do extract on dim0 using vector::ExtractStridedSliceOp
+  } else {
+    // do extract on dim0 using vector::ExtractStridedSliceOp
     // intermediates.resize(outGrids[0] * inGrids[1]);
     llvm::SmallVector<int64_t> blkSizes({outBlkSizes[0], inBlkSizes[1]});
-    // if the vnni transform applied, vector shape
-    // and offset need to be adjusted accordingly.
-    if (isVnniFormat) {
-      auto vnniAxis = isForDPASB ? 0 : 1;
-      auto factor = mlir::cast<mlir::VectorType>(inputs.front().getType())
-                        .getShape()
-                        .back();
-      blkSizes[vnniAxis] /= factor;
-    }
+
     // each vector will be horizonally cut into `nums` subvectors
     auto nums = outGrids[0] / inGrids[0];
     llvm::SmallVector<int64_t> strides({1, 1});
```
```diff
@@ -244,15 +236,6 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
       }
     } else { // doing extract on dim 1
       llvm::SmallVector<int64_t> blkSizes({outBlkSizes[0], outBlkSizes[1]});
-      // if vnni transform applied, vector shape
-      // and offset needs to adjusted accordingly.
-      if (isVnniFormat) {
-        auto vnniAxis = isForDPASB ? 0 : 1;
-        auto factor = mlir::cast<mlir::VectorType>(inputs.front().getType())
-                          .getShape()
-                          .back();
-        blkSizes[vnniAxis] /= factor;
-      }
      llvm::SmallVector<int64_t> strides({1, 1});
      auto nums = outGrids[1] / interGrids[1];
      for (auto i = 0; i < interGrids[0]; i++) {
```
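
With the isVnniFormat branch gone, lowerUnpackOrPack always slices plain 2-D vectors, so blkSizes never needs the divide-by-factor adjustment. A minimal sketch of the vector::ExtractStridedSliceOp it emits, with illustrative shapes not taken from the diff (one 8x16 block out of a 32x32 unpack):

```mlir
// Extract the block at grid position (1, 1); offsets advance in blkSizes
// steps and strides stay [1, 1], matching the code above.
%blk = vector.extract_strided_slice %v
    {offsets = [8, 16], sizes = [8, 16], strides = [1, 1]}
    : vector<32x32xf16> to vector<8x16xf16>
```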
```diff
@@ -289,14 +272,6 @@ class SgTileUnpackOpPattern : public XeOneToNConversion<xetile::TileUnpackOp> {
     auto inGrids = inTy.getShape().take_front(2);
     auto inBlkSizes = op.getInnerBlocksAttr();
 
-    // specific attention needed for vectors in vnni format,
-    // which is applied to load for dpas.
-    auto loadOp = op.getInVec().getDefiningOp<xetile::LoadTileOp>();
-    auto elemTy = op.getInVec().getType().getElementType();
-    bool isDpasB = loadOp && isForDPASB(loadOp);
-    bool isVnniFormat =
-        isDpasB && elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 32;
-
     // the default grids used as outGrids when unpack is not paired with a pack
     int64_t defautlOutGrids[2] = {1, 1};
     llvm::ArrayRef<int64_t> outGrids;
```
```diff
@@ -313,20 +288,9 @@ class SgTileUnpackOpPattern : public XeOneToNConversion<xetile::TileUnpackOp> {
       outBlkSizes = mlir::DenseI64ArrayAttr::get(ctx, outTy.getShape());
     }
 
-    // TODO: logically it is to do concat, but the data is in vnni format
-    // which breaks the concat logic, it transforms concat into stack.
-    if (isVnniFormat && (inBlkSizes[1] < outBlkSizes[1])) {
-      return op->emitOpError("[Unexpected rare case]: ")
-             << "It rarly happens that we need to do concat on vnni "
-             << "transformed vectors (which is 3D instead of 2D). "
-             << "It is essentially a stack on the 2nd dim, and is "
-             << "not implemented yet.\n";
-    }
-
     rewriter.setInsertionPoint(op);
-    auto newOps =
-        lowerUnpackOrPack(rewriter, op, inputs, inBlkSizes, outBlkSizes,
-                          inGrids, outGrids, isVnniFormat, isDpasB);
+    auto newOps = lowerUnpackOrPack(rewriter, op, inputs, inBlkSizes,
+                                    outBlkSizes, inGrids, outGrids);
 
     if (op->hasOneUse() && packOp) { // lowered Unpack and Pack as pair
       rewriter.replaceOp(packOp, newOps);
```
```diff
@@ -418,19 +382,8 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
     if (!innerBlocks || innerBlocks.size() != 2)
       return op.emitOpError("Missing valid innerBlock for the tile in op.");
 
-    bool hasColMajorTraversal =
-        tileTy.getOrder().asArrayRef() == mlir::ArrayRef({0, 1});
     // Need to make a copy, so we can swap values.
     auto innerBlk = llvm::to_vector(innerBlocks.asArrayRef());
-    // If order is [1, 0] (source memref is col-major), we need to swap the
-    // shape and innerBlocks because XeGPU ops only support row-major tile
-    // creation.
-    if (hasColMajorTraversal) {
-      assert(op.isSourceMemRef() && op.sourceMemRefHasStaticShape() &&
-             "Expecting a static shape source memref.");
-      std::swap(innerBlk[0], innerBlk[1]);
-      std::swap(shape[0], shape[1]);
-    }
 
     // using array_length for load if dim1 of innerBlocks is smaller than
     // dim1 of shape.
```
```diff
@@ -440,22 +393,6 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
             ? getBlockArrayLength(op, elemTy, innerBlk[1], shape[1])
             : 1;
 
-    // If the source memref is col-major we need to convert into a row-major
-    // because at XeGPU level we only support row-major memrefs. Also
-    // array_length must be 1 if col-major memrefs are used as the source.
-    if (hasColMajorTraversal) {
-      array_length = 1;
-      // create a memref.reinterpret_cast to convert col-major to row-major
-      auto rowMajorSourceShape =
-          swapLastTwoElements(op.getSourceMemrefStaticShape());
-      auto rowMajorSourceStrides = defaultStrides(rowMajorSourceShape);
-      int64_t rowMajorSourceOffset = 0;
-      auto newMemRefTy = mlir::MemRefType::get(rowMajorSourceShape, elemTy);
-      source = rewriter.create<mlir::memref::ReinterpretCastOp>(
-          loc, newMemRefTy, source, rowMajorSourceOffset, rowMajorSourceShape,
-          rowMajorSourceStrides);
-    }
-
     auto width = array_length * innerBlk[1];
 
     llvm::SmallVector<int64_t, 2> blocks(
```
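
For reference, the removed branch wrapped col-major sources in a memref.reinterpret_cast with swapped sizes and default (row-major) strides. A hypothetical sketch of the IR it used to emit, assuming a 64x128 f16 col-major source (shapes and names are illustrative, not from the commit):

```mlir
// View the col-major buffer as a row-major memref of the transposed shape;
// offset 0 and strides [64, 1] are the defaults the removed code computed.
%rowMajor = memref.reinterpret_cast %src to
    offset: [0], sizes: [128, 64], strides: [64, 1]
    : memref<64x128xf16, strided<[1, 64]>> to memref<128x64xf16>
```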
```diff
@@ -476,8 +413,6 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
     // For col-major memref initial offsets need to be swapped.
     auto offsetsY = offsets.pop_back_val();
     auto offsetsX = offsets.pop_back_val();
-    if (hasColMajorTraversal)
-      std::swap(offsetsX, offsetsY);
 
     auto tDescTy = mlir::xegpu::TensorDescType::get(
         innerBlk, elemTy, array_length, true /*boundary_check*/, memoryScope);
```
```diff
@@ -513,12 +448,8 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
         auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
             op.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/,
             tDescOffsets /*offsets*/);
-        // col-major source memref requires creating the tiles in transposed
-        // order
-        if (hasColMajorTraversal)
-          xegpuOps[blocks[0] * j + i] = createNdOp;
-        else
-          xegpuOps[blocks[1] * i + j] = createNdOp;
+
+        xegpuOps[blocks[1] * i + j] = createNdOp;
       } else {
         return mlir::failure();
       }
```
```diff
@@ -620,55 +551,21 @@ struct SgLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
     auto sources = adaptor.getSource();
 
     auto ctx = op.getContext();
-    auto L1 = mlir::xegpu::CachePolicyAttr::get(
-        ctx, mlir::xegpu::CachePolicy::CACHED);
-    auto L2 = mlir::xegpu::CachePolicyAttr::get(
-        ctx, mlir::xegpu::CachePolicy::CACHED);
-    auto L3 = mlir::xegpu::CachePolicyAttr::get(
-        ctx, mlir::xegpu::CachePolicy::CACHED);
-
-    mlir::UnitAttr vnniAttr = nullptr;
-    mlir::IntegerAttr transposeBitWidthAttr;
-    // TODO: move these two into architecture abstracture in future.
-    const int SIMD_WIDTH_IN_BITS = 32;
-    int factor = SIMD_WIDTH_IN_BITS / elemTy.getIntOrFloatBitWidth();
-    // TODO: use uArch for this?
-    auto isLowPrecision = [](unsigned int width) -> bool {
-      bool isPowerOf2 = (width & (width - 1)) == 0;
-      return isPowerOf2 & (width < 32) & (width > 1);
+
+    auto getDefaultCachePolicy = [&]() {
+      return mlir::xegpu::CachePolicyAttr::get(
+          ctx, mlir::xegpu::CachePolicy::CACHED);
     };
-    // vnni can only be applied when the blockSZ[0] >= factor
-    // for shape, e.g., 1xN, vnni cannot be applied, since no
-    // vnni transform available)
-    if (isForDPASB(op) && factor > 1 && blockSZ[0] >= factor)
-      vnniAttr = mlir::UnitAttr::get(ctx);
 
-    mlir::DenseI64ArrayAttr transposeAttr;
-    auto srcOrder = tileTy.getOrder();
-    if (srcOrder.asArrayRef() == mlir::ArrayRef({1, 0})) {
-      // Nothing to do
-    } else if (srcOrder.asArrayRef() == mlir::ArrayRef({0, 1})) {
-      auto elemWidth = elemTy.getIntOrFloatBitWidth();
-      if (elemWidth == 32) {
-        transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0});
-      } else if (isLowPrecision(elemWidth) && vnniAttr) {
-        transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0});
-        transposeBitWidthAttr = rewriter.getI32IntegerAttr(32);
-        vnniAttr = nullptr;
-      } else {
-        return ((mlir::PatternRewriter &)rewriter)
-            .notifyMatchFailure(op, "Unsupported element type for transpose");
-      }
-    } else {
-      return ((mlir::PatternRewriter &)rewriter)
-          .notifyMatchFailure(op, "Unsupported order");
-    }
+    auto L1 = getDefaultCachePolicy();
+    auto L2 = getDefaultCachePolicy();
+    auto L3 = getDefaultCachePolicy();
 
-    // vnni and transpose are not available for SLM memory scope.
-    if (tileTy.getMemoryScopeAsInt() == 3) {
-      vnniAttr = nullptr;
-      transposeBitWidthAttr = nullptr;
-    }
+    // The tile is in col-major order, which should be canonicalized to
+    // row-major in canonicalization pass.
+    auto srcOrder = tileTy.getOrder();
+    if (srcOrder.asArrayRef() != mlir::ArrayRef({1, 0}))
+      return mlir::failure();
 
     rewriter.setInsertionPoint(op);
     llvm::SmallVector<::mlir::Value> xegpuOps;
```
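
The load pattern now simply bails out on any order other than row-major [1, 0]. Assuming imex's usual tile-attribute syntax (illustrative, not part of this diff), a tile type like the following no longer matches and must first be rewritten to row-major by the canonicalization pass:

```mlir
// order = [0, 1] marks col-major traversal; SgLoadTileOpPattern returns
// mlir::failure() for such tiles after this commit.
!xetile.tile<32x32xf16, #xetile.tile_attr<order = [0, 1]>>
```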
```diff
@@ -678,22 +575,12 @@ struct SgLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
       auto shape = tdescTy.getShape().vec();
       auto array_length = tdescTy.getArrayLength();
 
-      if (transposeAttr)
-        std::swap(shape[0], shape[1]);
-
-      if (vnniAttr || transposeBitWidthAttr) {
-        const int axis = 0;
-        shape[axis] /= factor;
-        shape.push_back(factor);
-      }
-
       if (array_length != 1)
         shape.insert(shape.begin(), array_length);
 
       auto vectorTy = mlir::VectorType::get(shape, elemTy);
       auto ldOp = rewriter.create<mlir::xegpu::LoadNdOp>(
-          op.getLoc(), vectorTy, src, vnniAttr, transposeAttr,
-          transposeBitWidthAttr, L1, L2, L3);
+          op.getLoc(), vectorTy, src, nullptr, nullptr, nullptr, L1, L2, L3);
       if (array_length == 1) {
         xegpuOps.push_back(ldOp);
       } else {
```
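
With vnni, transpose, and transposeBitWidth all passed as nullptr, every tile load now lowers to an undecorated xegpu.load_nd carrying only the cache hints. A hedged sketch of the resulting IR (the exact cache-hint attribute spelling varies across XeGPU dialect versions):

```mlir
// Plain 2-D load: the result type mirrors the tensor_desc shape, with no
// VNNI reshape or transpose applied.
%v = xegpu.load_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                           l2_hint = #xegpu.cache_hint<cached>,
                           l3_hint = #xegpu.cache_hint<cached>}
    : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
```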

test/Conversion/XeTileToXeGPU/gemm_preop.mlir

Lines changed: 1 addition & 1 deletion
```diff
@@ -71,7 +71,7 @@ module attributes {gpu.container_module} {
     %31 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<32x32xf16>, index, index -> !xetile.tile<32x32xf16>
     xegpu.compile_hint
 
-    // CHECK-COUNT-16: {{.*}} = xegpu.dpas {{.*}}, {{.*}}, {{.*}} : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    // CHECK-COUNT-16: {{.*}} = xegpu.dpas {{.*}}, {{.*}}, {{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
     %32 = xetile.tile_mma %29, %28, %arg7 : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
     xegpu.compile_hint
     scf.yield %30, %31, %32 : !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>
```
