
Commit ae9fe17

Update blocking pass to support slm (#838)
Accesses (load/store) to SLM use different instructions and have different size constraints compared to block load/store operations for global memory.
1 parent a028590 commit ae9fe17

File tree

9 files changed: +268 −64 lines changed


include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h

Lines changed: 5 additions & 3 deletions
@@ -159,12 +159,14 @@ class XeOneToNConversion : public XeConversionPattern<TileUsageAnalysis> {
     // (convertedTypes.size() == 1) we will reuse the current value. Otherwise,
     // it has one-to-n mapping, and the new value should be an
     // UnrealizedConversionCastOp.
-    for (auto &value : remappedValues) {
+    for (size_t i = 0; i < remappedValues.size(); i++) {
+      auto value = remappedValues[i];
       auto castOp = value.getDefiningOp<mlir::UnrealizedConversionCastOp>();
-      if (castOp && castOp.getInputs().size() > 1)
+      auto valueTy = value.getType();
+      if (castOp && valueTy == op->getOperand(i).getType())
         convertedValues.push_back(castOp.getInputs());
       else
-        convertedValues.push_back(remappedValues[i]);
     }

     auto sourceOp = llvm::dyn_cast<SourceOp>(op);
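For context, the UnrealizedConversionCastOp here acts as a temporary container: when one xetile value lowers to n xegpu values, the pieces are packed into a cast whose result keeps the original type, and the loop above unpacks them for consumers. A minimal sketch of the packing side, assuming `builder`, `loc`, and the converted `pieces` are in scope (the helper name is illustrative, not from this patch):

// Hedged sketch: fold n converted values back into one value of the
// original type so that not-yet-converted users still type-check.
// `origType` is the type of the value before conversion.
mlir::Value packPieces(mlir::OpBuilder &builder, mlir::Location loc,
                       mlir::ValueRange pieces, mlir::Type origType) {
  auto cast = builder.create<mlir::UnrealizedConversionCastOp>(
      loc, mlir::TypeRange{origType}, pieces);
  return cast.getResult(0);
}

The unpack direction in the diff then recognizes such a cast by checking that its result type still matches the original operand type, and forwards castOp.getInputs().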

include/imex/Dialect/XeTile/IR/XeTileOps.h

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
 
 #ifndef _XETILE_OPS_H_INCLUDED_
 #define _XETILE_OPS_H_INCLUDED_
-
 #include <mlir/Dialect/Vector/IR/VectorOps.h>
 #include <mlir/IR/BuiltinTypeInterfaces.h>
 #include <mlir/IR/BuiltinTypes.h>

include/imex/Utils/XeCommon.h

Lines changed: 5 additions & 2 deletions
@@ -268,8 +268,11 @@ class PropagateAnalysis {
 
   auto *op = getDefineOrParentOp(value);
 
-  // stop when meet a function.
-  if (!op || llvm::isa<mlir::FunctionOpInterface>(op))
+  // stop when meeting a function or a cast op, e.g., arith.truncf,
+  // since their sources and results could have different bitwidths,
+  // in which case the block size cannot be propagated.
+  if (!op || llvm::isa<mlir::FunctionOpInterface>(op) ||
+      llvm::isa<mlir::CastOpInterface>(op))
     continue;
 
   OpAttrMap[value] = attr;
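To see why bitwidth-changing casts block propagation: legal inner blocks are bounded in bytes rather than elements, so the same block shape implies a different message size on each side of an arith.truncf. A hedged illustration with an assumed 512-byte budget (the constant is hypothetical, not taken from the pass):

// Illustration only: a block shape that fills the byte budget for
// 32-bit elements occupies half the budget after truncation to 16 bits,
// so the same inner_blocks attribute cannot simply cross the cast.
constexpr int kMaxBlockBytes = 512; // assumed hardware budget

constexpr int blockBytes(int elemBits, int rows, int cols) {
  return rows * cols * elemBits / 8;
}
static_assert(blockBytes(32, 8, 16) == kMaxBlockBytes, "f32 block fills budget");
static_assert(blockBytes(16, 8, 16) == kMaxBlockBytes / 2, "f16 block does not");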

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 10 additions & 1 deletion
@@ -637,7 +637,10 @@ struct SgLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
   bool isPowerOf2 = (width & (width - 1)) == 0;
   return isPowerOf2 & (width < 32) & (width > 1);
 };
-if (isForDPASB(op) && factor > 1)
+// vnni can only be applied when blockSZ[0] >= factor; for a shape
+// such as 1xN, no vnni transform is available.
+if (isForDPASB(op) && factor > 1 && blockSZ[0] >= factor)
   vnniAttr = mlir::UnitAttr::get(ctx);
 
 mlir::DenseI64ArrayAttr transposeAttr;
@@ -661,6 +664,12 @@ struct SgLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
           .notifyMatchFailure(op, "Unsupported order");
   }
 
+  // vnni and transpose are not available for SLM memory scope.
+  if (tileTy.getMemoryScopeAsInt() == 3) {
+    vnniAttr = nullptr;
+    transposeBitWidthAttr = nullptr;
+  }
+
   rewriter.setInsertionPoint(op);
   llvm::SmallVector<::mlir::Value> xegpuOps;
   for (auto src : sources) {
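The new guard mirrors the vnni packing rule: the factor is 32 divided by the element bitwidth, and a block must span at least `factor` rows to be packed. A standalone sketch of that predicate (hypothetical helper, same convention as the patch):

#include <cstdint>

// Hedged sketch: whether vnni packing applies to a block with the given
// row count, assuming factor = 32 / element bitwidth as in the patch.
bool canApplyVnni(int elementSizeInBits, int64_t blockRows) {
  int factor = 32 / elementSizeInBits; // 2 for f16/bf16, 4 for i8
  return factor > 1 && blockRows >= factor;
}
// canApplyVnni(16, 1) == false: a 1xN bf16 block has no vnni transform.
// canApplyVnni(16, 8) == true.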

lib/Dialect/XeTile/Transforms/Blocking.cpp

Lines changed: 65 additions & 49 deletions
@@ -14,7 +14,6 @@
 /// such that each piece can be handled by a hardware instruction.
 ///
 //===----------------------------------------------------------------------===//
-
 #include <mlir/Conversion/LLVMCommon/TypeConverter.h>
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/Func/IR/FuncOps.h>
@@ -779,66 +778,83 @@ struct InitTileOpPattern
         op, "Skipped InitTileOp because the result tile is not rank 2.\n");
 
     auto innerBlocks = tileTy.getInnerBlocks();
+    auto memorySpace = op.getSourceMemorySpaceAsInt();
 
     // skip it if innerBlocks has been set by user or compiler.
     if (innerBlocks)
       return mlir::failure();
 
     auto elemTy = tileTy.getElementType();
     int elementSize = elemTy.getIntOrFloatBitWidth();
-    if (isForPrefetch(op)) {
-      innerBlocks = mlir::DenseI64ArrayAttr::get(
-          getContext(), getInnerBlockSizes<Prefetch>(
-                            op.getOperation(), elemTy, tileTy.getShape()[0],
-                            tileTy.getShape()[1], this->uArchInterface));
-    } else if (isForLoad(op)) {
-
-      // Set transpose and vnni
-      bool vnni = false;
-      bool transpose = false;
-
-      auto order = tileTy.getOrder();
-      if (order[0] == 0 && order[1] == 1)
-        transpose = true;
-
-      for (auto user : getEffectiveUsers(op)) {
-        if (auto loadTileOp = llvm::dyn_cast<xetile::LoadTileOp>(user)) {
-          if (isForDPASB(loadTileOp) && elementSize < 32) {
-            vnni = true;
-            break;
+
+    if (memorySpace == 3) { // for shared memory
+      const unsigned int lscConstraints = 512; // 512 bytes constraint by lsc
+      const unsigned int subgroupSize = 16;
+      auto shape = tileTy.getShape();
+      int64_t innerBlockSizes[2];
+      // prefer to use gather loads with 16 simd lanes
+      innerBlockSizes[0] = shape[0] % subgroupSize == 0 ? 16 : 1;
+      innerBlockSizes[1] =
+          (lscConstraints * 8) / (elementSize * innerBlockSizes[0]);
+      innerBlockSizes[1] =
+          std::min<int64_t>(innerBlockSizes[1], tileTy.getShape()[1]);
+      innerBlocks = mlir::DenseI64ArrayAttr::get(getContext(), innerBlockSizes);
+    } else { // for global memory
+      if (isForPrefetch(op)) {
+        innerBlocks = mlir::DenseI64ArrayAttr::get(
+            getContext(), getInnerBlockSizes<Prefetch>(
+                              op.getOperation(), elemTy, tileTy.getShape()[0],
+                              tileTy.getShape()[1], this->uArchInterface));
+      } else if (isForLoad(op)) {
+
+        // Set transpose and vnni
+        bool vnni = false;
+        bool transpose = false;
+
+        auto order = tileTy.getOrder();
+        if (order[0] == 0 && order[1] == 1)
+          transpose = true;
+
+        for (auto user : getEffectiveUsers(op)) {
+          if (auto loadTileOp = llvm::dyn_cast<xetile::LoadTileOp>(user)) {
+            if (isForDPASB(loadTileOp) && elementSize < 32) {
+              vnni = true;
+              break;
+            }
           }
         }
-      }
 
-      if (vnni && transpose && elementSize < 32) {
-        int factor = 32 / elementSize;
-        vnni = false;
-        llvm::SmallVector<int64_t, 2> innerBlock = getInnerBlockSizes<Load>(
-            op.getOperation(), mlir::FloatType::getF32(getContext()),
-            tileTy.getShape()[1], (tileTy.getShape()[0]) / factor,
-            this->uArchInterface, vnni, transpose);
-        std::swap(innerBlock[0], innerBlock[1]);
-        innerBlock[0] *= factor;
-        innerBlocks = mlir::DenseI64ArrayAttr::get(getContext(), innerBlock);
-
-      } else if (transpose && elementSize < 32) {
-        return rewriter.notifyMatchFailure(op, "Invalid transpose.");
-      } else {
+        if (vnni && transpose && elementSize < 32) {
+          int factor = 32 / elementSize;
+          vnni = false;
+          llvm::SmallVector<int64_t, 2> innerBlock = getInnerBlockSizes<Load>(
+              op.getOperation(), mlir::FloatType::getF32(getContext()),
+              tileTy.getShape()[1], (tileTy.getShape()[0]) / factor,
+              this->uArchInterface, vnni, transpose);
+          std::swap(innerBlock[0], innerBlock[1]);
+          innerBlock[0] *= factor;
+          innerBlocks = mlir::DenseI64ArrayAttr::get(getContext(), innerBlock);
+
+        } else if (transpose && elementSize < 32) {
+          return rewriter.notifyMatchFailure(op, "Invalid transpose.");
+        } else {
+          innerBlocks = mlir::DenseI64ArrayAttr::get(
+              getContext(),
+              getInnerBlockSizes<Load>(
+                  op.getOperation(), elemTy, tileTy.getShape()[0],
+                  tileTy.getShape()[1], this->uArchInterface, vnni, transpose));
+        }
+      } else if (isForStore(op)) {
         innerBlocks = mlir::DenseI64ArrayAttr::get(
-            getContext(),
-            getInnerBlockSizes<Load>(op.getOperation(), elemTy,
-                                     tileTy.getShape()[0], tileTy.getShape()[1],
-                                     this->uArchInterface, vnni, transpose));
+            getContext(), getInnerBlockSizes<Store>(
+                              op.getOperation(), elemTy, tileTy.getShape()[0],
+                              tileTy.getShape()[1], this->uArchInterface));
+      } else {
+        return rewriter.notifyMatchFailure(
+            op,
+            "The tile is used for multiple purpose. The init-duplicate pass "
+            "should be run first to resolve this issue.");
       }
-    } else if (isForStore(op)) {
-      innerBlocks = mlir::DenseI64ArrayAttr::get(
-          getContext(), getInnerBlockSizes<Store>(
-                            op.getOperation(), elemTy, tileTy.getShape()[0],
-                            tileTy.getShape()[1], this->uArchInterface));
-    } else {
-      return rewriter.notifyMatchFailure(
-          op, "The tile is used for multiple purpose. The init-duplicate pass "
-              "should be run first to resolve this issue.");
     }
 
     if (innerBlocks.empty()) {
lib/Utils/XeArch.cpp

Lines changed: 8 additions & 0 deletions
@@ -303,6 +303,10 @@ mlir::LogicalResult XeuArchInterface::isLegalLoad2dOp(mlir::Operation *op) {
   if (auto loadOp = llvm::dyn_cast<mlir::xegpu::LoadNdOp>(op)) {
     auto tdescTy = loadOp.getTensorDescType();
 
+    // TODO: need more thinking on SLM
+    if (tdescTy.getMemoryScope() == mlir::xegpu::MemoryScope::SLM)
+      return mlir::success();
+
     int elementSize = loadOp.getTensorDescType().getElementTypeBitWidth();
 
     LoadStore2DConfig loadParams;
@@ -342,6 +346,10 @@ mlir::LogicalResult XeuArchInterface::isLegalStore2dOp(mlir::Operation *op) {
     auto tdescTy = storeOp.getTensorDescType();
     int elementSize = tdescTy.getElementTypeBitWidth();
 
+    // TODO: need more thinking on SLM
+    if (tdescTy.getMemoryScope() == mlir::xegpu::MemoryScope::SLM)
+      return mlir::success();
+
     LoadStore2DConfig storeParams;
     bool vnni = false;
     bool transpose = false;
