Skip to content

Commit c9bd5aa

Browse files
committed
Find suitable multiple
1 parent a6c9e49 commit c9bd5aa

File tree

3 files changed

+80
-15
lines changed

3 files changed

+80
-15
lines changed

mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ namespace mlir {
2929
namespace xegpu {
3030
namespace uArch {
3131

32+
constexpr unsigned generalPackedFormatBitSize{32};
33+
3234
// An enum class to represent the scope of an instruction
3335
enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
3436
enum class InstructionKind {

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,10 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
570570

571571
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
572572
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
573-
constexpr int packingBitSizeGatherScatter{32};
574-
int chunkAlignmentFactor = bitWidth < packingBitSizeGatherScatter
575-
? packingBitSizeGatherScatter / bitWidth
576-
: 1;
573+
int chunkAlignmentFactor =
574+
bitWidth < xegpu::uArch::generalPackedFormatBitSize
575+
? xegpu::uArch::generalPackedFormatBitSize / bitWidth
576+
: 1;
577577
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
578578
if (scatterAttr) {
579579
int64_t chunkSize = scatterAttr.getChunkSizeAsInt();

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,28 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
204204
using Lattice::Lattice;
205205
};
206206

207+
/// Helper Function to find a proper instruction multiple for the user-supplied
208+
/// sg-level data shape. `candidates` are uArch allowed shapes.
209+
/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
210+
template <typename T>
211+
int getLargestDivisor(T dim, ArrayRef<T> candidates,
212+
ArrayRef<T> candidateMultiples = {}) {
213+
static_assert(std::is_integral<T>::value, "T must be an integer type");
214+
int largest = -1;
215+
SmallVector<T> multiples = {1};
216+
if (!candidateMultiples.empty())
217+
multiples =
218+
SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
219+
for (T candidate : candidates) {
220+
for (T multiple : multiples) {
221+
int value = static_cast<int>(candidate * multiple);
222+
if (value != 0 && dim % value == 0 && value > largest)
223+
largest = value;
224+
}
225+
}
226+
return largest;
227+
}
228+
207229
/// Helper Functions to get default layouts. A `default layout` is a layout that
208230
/// is assigned to a value when the layout is not fixed by some anchor operation
209231
/// (like DPAS).
@@ -482,12 +504,23 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
482504
if (!blockWHC)
483505
prefetch.emitWarning("No known block params found for the element type.");
484506
auto [bWidth, bHeight, bCount] = blockWHC.value();
485-
486507
SmallVector<int> instData;
508+
int instWidth = getLargestDivisor(
509+
static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
510+
bCount);
511+
if (instWidth == -1)
512+
prefetch.emitWarning(
513+
"No suitable instruction multiple found for the given shape.");
487514
if (tdescTy.getRank() == 1)
488-
instData = {bWidth.back() * bCount.back()};
489-
else
490-
instData = {bHeight.back(), bWidth.back() * bCount.back()};
515+
instData = {instWidth};
516+
else {
517+
int instHeight = getLargestDivisor(
518+
static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
519+
if (instHeight == -1)
520+
prefetch.emitWarning(
521+
"No suitable instruction multiple found for the given shape.");
522+
instData = {instHeight, instWidth};
523+
}
491524
auto prefetchLayout = getDefaultSIMTLayoutInfo(
492525
tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());
493526
// Propagate the layout to the source tensor descriptor.
@@ -597,10 +630,22 @@ void LayoutInfoPropagation::visitDpasOp(
597630
const auto *uArchInstruction =
598631
dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
599632
xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
633+
634+
const unsigned dataALen = aTy.getShape().front();
635+
auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
600636
const int maxALen =
601-
uArchInstruction->getSupportedM(aTy.getElementType()).back();
637+
getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
638+
if (maxALen == -1)
639+
dpas.emitWarning(
640+
"No suitable instruction multiple found for the given shape.");
641+
642+
const unsigned dataBLen = bTy.getShape().back();
643+
auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
602644
const int maxBLen =
603-
uArchInstruction->getSupportedK(bTy.getElementType()).back();
645+
getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
646+
if (maxBLen == -1)
647+
dpas.emitWarning(
648+
"No suitable instruction multiple found for the given shape.");
604649
SmallVector<int> instDataA = {maxALen, subgroupSize};
605650
SmallVector<int> instDataB = {subgroupSize, maxBLen};
606651

@@ -614,8 +659,13 @@ void LayoutInfoPropagation::visitDpasOp(
614659
uArchInstruction->getPackedFormatBitSizeB())));
615660
if (operands.size() > 2) {
616661
VectorType cTy = dpas.getAccType();
662+
const unsigned dataCLen = bTy.getShape().back();
663+
auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
617664
const int maxCLen =
618-
uArchInstruction->getSupportedN(bTy.getElementType()).back();
665+
getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
666+
if (maxCLen == -1)
667+
dpas.emitWarning(
668+
"No suitable instruction multiple found for the given shape.");
619669
SmallVector<int> instDataC = {maxALen, maxCLen};
620670
propagateIfChanged(operands[2],
621671
operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
@@ -634,16 +684,29 @@ void LayoutInfoPropagation::visitStoreNdOp(
634684
dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
635685
uArch->getInstruction(
636686
xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
687+
VectorType dataTy = store.getValueType();
637688
auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
638689
store.getValueType().getElementType());
639690
if (!blockWHC)
640691
store.emitWarning("No known block params found for the element type.");
641692
auto [bWidth, bHeight, bCount] = blockWHC.value();
642693
SmallVector<int> instData;
643-
if (store.getValueType().getRank() == 1)
644-
instData = {bWidth.back() * bCount.back()};
645-
else
646-
instData = {bHeight.back(), bWidth.back() * bCount.back()};
694+
int instWidth = getLargestDivisor(
695+
static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
696+
bCount);
697+
if (instWidth == -1)
698+
store.emitWarning(
699+
"No suitable instruction multiple found for the given shape.");
700+
if (dataTy.getRank() == 1)
701+
instData = {instWidth};
702+
else {
703+
int instHeight = getLargestDivisor(
704+
static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
705+
if (instHeight == -1)
706+
store.emitWarning(
707+
"No suitable instruction multiple found for the given shape.");
708+
instData = {instHeight, instWidth};
709+
}
647710
LayoutInfo storeLayout =
648711
getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
649712
uArchInstruction->getPackedFormatBitSize());

0 commit comments

Comments
 (0)