
Commit 7cbeebb: small bug fixes
1 parent: 66ab4aa

File tree: 2 files changed (+110, -64 lines)


mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 103 additions & 55 deletions
@@ -267,7 +267,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     return success();
   }
 };
-
+/*
 struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
   using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
@@ -298,6 +298,49 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     return success();
   }
 };
+*/
+
+struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
+  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getValueType();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    LDBG("UnrollStoreNdOp: targetShape present? " << (targetShape.has_value() ? "yes" : "no"));
+    if (!targetShape)
+      return failure();
+
+    LDBG("targetShape: ");
+    for (auto v : *targetShape) LDBG("  " << v);
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    LDBG("convertedValTypes size: " << convertedValTypes.size());
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    LDBG("convertedTdescTypes size: " << convertedTdescTypes.size());
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    LDBG("convertedValues size: " << convertedValues.size());
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+    LDBG("convertedTdescs size: " << convertedTdescs.size());
+
+    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
+      LDBG("Creating StoreNdOp with value: " << v << ", tdesc: " << t);
+      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+    }
+
+    LDBG("Erasing original StoreNdOp: " << op);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
 
 struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
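
Note on the pattern shape: the reinstated UnrollStoreNdOp follows the same unroll recipe as the other patterns in this file. It queries getTargetShape, derives per-tile types with getUnrolledTypes, packs the operands into per-tile values, emits one op per tile, and erases the original; the LDBG lines only trace each step. Below is a minimal standalone sketch of just the tile-count arithmetic, in plain C++ with a hypothetical numUnrolledOps helper (not part of the pass), assuming, as the real pattern does, that each dimension divides evenly:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: how many small ops replace one big op when a shape
// is unrolled to a target tile shape (assumes even divisibility).
int64_t numUnrolledOps(const std::vector<int64_t> &shape,
                       const std::vector<int64_t> &targetShape) {
  int64_t count = 1;
  for (std::size_t i = 0; i < shape.size(); ++i)
    count *= shape[i] / targetShape[i]; // per-dimension ratio
  return count;
}

int main() {
  // E.g. a 16x32 store unrolled to 8x16 tiles becomes 2*2 = 4 StoreNdOps.
  std::cout << numUnrolledOps({16, 32}, {8, 16}) << "\n"; // prints 4
}
```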
@@ -402,37 +445,40 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
+
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    // indiceVec is one dim lower than tdescTy when chunkSize is larger than 1.
+    if (originalChunkSize > 1)
+      targetIndiceShape.pop_back();
 
     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-    VectorType indiceVecTy = indiceVec.getType();
-
-    SmallVector<Type> convertedIndiceTypes;
-    SmallVector<Value> convertedIndiceVec;
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, targetIndiceShape);
+    SmallVector<Value> convertedIndiceVec =
+        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
+
     SmallVector<Value> newOps;
 
-    if (tdescTy.getRank() == 2) {
-      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, shape1D);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, shape1D, loc, rewriter);
-
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+    // More indices are needed when chunkSize > 1, since one big load from a
+    // single address may be broken into multiple small loads.
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
 
       for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           // Compute the offset
-          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunk);
+          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunkSize);
           Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
           Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
@@ -443,8 +489,6 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
         }
       }
     } else {
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
       for (auto indice : convertedIndiceVec) {
         auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
                                                           op.getSource(), indice);
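
The chunkSize > 1 branch above is the heart of this hunk: each unrolled indice vector is reused numNewChunks times, with i * blockedChunkSize splatted and added so that every sub-chunk addresses its own slice of the original chunk. A standalone sketch of that arithmetic, using plain integer vectors in place of the ConstantIndexOp/SplatOp/AddIOp sequence (expandIndices is an illustrative name, not a pass helper):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for the splat-and-add offset expansion above.
std::vector<std::vector<int64_t>>
expandIndices(const std::vector<int64_t> &indice, int64_t originalChunkSize,
              int64_t blockedChunkSize) {
  int64_t numNewChunks = originalChunkSize / blockedChunkSize;
  std::vector<std::vector<int64_t>> result;
  for (int64_t i = 0; i < numNewChunks; ++i) {
    std::vector<int64_t> offset = indice;
    for (int64_t &v : offset)
      v += i * blockedChunkSize; // the splatted increment
    result.push_back(offset);
  }
  return result;
}

int main() {
  // chunkSize 8 blocked into chunks of 2: base indices {0, 64} expand to
  // {0,64}, {2,66}, {4,68}, {6,70}.
  for (const auto &vec : expandIndices({0, 64}, 8, 2)) {
    for (int64_t v : vec)
      std::cout << v << " ";
    std::cout << "\n";
  }
}
```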
@@ -468,15 +512,17 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
+
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
@@ -489,25 +535,26 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
 
-    if (tdescTy.getRank() == 2) {
-      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
-      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+    if (originalChunkSize > 1) {
+      targetMaskShape.pop_back();
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
 
       for (auto mask : convertedMasks1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
+      // This is to handle the transpose effect when chunkSize > 1.
       if (targetShape && targetShape->size() > 1) {
         std::swap((*targetShape)[0], (*targetShape)[1]);
         newValueTy = valueTy.cloneWith(*targetShape, elemTy);
       }
     } else {
-      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
-      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
     }
 
     SmallVector<Value> newOps;
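
For LoadGatherOp the masks are per row rather than per chunk, so when chunkSize > 1 each unrolled 1-D mask guards numNewChunks smaller loads of the same rows and is simply repeated rather than re-split; the dimension swap afterwards reflects the transposed layout of chunked loads. A minimal sketch of the replication step (replicateMasks is an illustrative name, not a pass helper):

```cpp
#include <cstdint>
#include <vector>

// Illustrative stand-in for the nested mask-replication loop above: each
// 1-D mask stays valid for every sub-chunk carved out of the original chunk.
template <typename Mask>
std::vector<Mask> replicateMasks(const std::vector<Mask> &masks1D,
                                 int64_t numNewChunks) {
  std::vector<Mask> out;
  for (const Mask &m : masks1D)
    for (int64_t i = 0; i < numNewChunks; ++i)
      out.push_back(m);
  return out;
}
// replicateMasks<int>({1, 0}, 2) yields {1, 1, 0, 0}.
```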
@@ -561,38 +608,38 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
 
-    if (tdescTy.getRank() == 2) {
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
       convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
       SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
 
       for (auto mask : convertedMasks1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
-
+      // This is to handle the transpose effect when chunkSize > 1.
       std::swap((*targetShape)[0], (*targetShape)[1]);
 
     } else {
@@ -626,8 +673,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (tdescTy.getRank() > 2)
+      return failure();
+
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -644,18 +693,17 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     SmallVector<Type> convertedOffsetTypes;
     SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
-
-    if (tdescTy.getRank() == 2) {
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    if (originalChunkSize > 1) {
       SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
       convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
       SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
 
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
 
       for (auto offset : convertedOffsetVec1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedOffsetVec.push_back(offset);
         }
       }
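
UnrollUpdateOffsetOp applies the same idea to the offset vectors: with chunkSize > 1 the offsets have one dimension fewer than the descriptor, so they are unrolled against the target shape minus its innermost (chunk) dimension and then repeated per new chunk. A sketch of the shape handling, under the assumption (taken from the code above) that the chunk dimension is always innermost; offsetUnrollShape is an illustrative name:

```cpp
#include <cstdint>
#include <vector>

// Illustrative stand-in for the shape1D computation above: drop the
// innermost (chunk) dimension when the descriptor is chunked.
std::vector<int64_t> offsetUnrollShape(std::vector<int64_t> targetShape,
                                       int64_t originalChunkSize) {
  if (originalChunkSize > 1)
    targetShape.pop_back();
  return targetShape;
}
// offsetUnrollShape({8, 2}, 8) == {8}; offsetUnrollShape({16}, 1) == {16}.
```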

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 7 additions & 9 deletions
@@ -105,25 +105,23 @@ struct TestXeGPUUnrollingPatterns
       Attribute encoding = tdescTy.getEncoding();
       auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
           tdescTy.getLayout());
-
-      int64_t newChunkSize = 0;
-      auto instData = layout.getInstData();
-      if (!instData.empty())
-        newChunkSize = instData.asArrayRef().back();
-
+
       if (layout) {
         if (layout.getLaneLayout() == nullptr)
           layout = xegpu::LayoutAttr();
         else
           layout = layout.dropInstData();
       }
 
-      SmallVector<NamedAttribute> attrs;
-      auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
-      if (scatterAttr) {
+      if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+        auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
         int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
         if (chunkSize > 1) {
+          int64_t newChunkSize = chunkSize;
+          auto instData = layout.getInstData();
+          if (!instData.empty())
+            newChunkSize = instData.asArrayRef().back();
 
           auto chunkSizeAttr = mlir::IntegerAttr::get(
               mlir::IntegerType::get(ctx, 64), newChunkSize);
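
The test-pass change reorders the newChunkSize computation so that layout.getInstData() is no longer called before the null check on layout, and the default becomes the original chunkSize instead of 0; this is plausibly one of the "small bug fixes" named in the commit message. A standalone sketch of the new derivation (computeNewChunkSize is an illustrative name, with a plain vector standing in for the inst_data attribute):

```cpp
#include <cstdint>
#include <vector>

// Illustrative stand-in for the reordered logic above: default to the
// original chunk size, override with the innermost inst_data entry if set.
int64_t computeNewChunkSize(int64_t chunkSize,
                            const std::vector<int64_t> &instData) {
  int64_t newChunkSize = chunkSize; // safe default
  if (!instData.empty())
    newChunkSize = instData.back(); // innermost inst_data dimension
  return newChunkSize;
}
// computeNewChunkSize(8, {16, 2}) == 2; computeNewChunkSize(8, {}) == 8.
```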
