Commit 66ab4aa

add unrolling support for load/store/prefetch/update with chunk_size
1 parent 248981b commit 66ab4aa


mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 90 additions & 41 deletions
@@ -421,43 +421,36 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     SmallVector<Value> newOps;
 
     if (tdescTy.getRank() == 2) {
-      SmallVector<int64_t> oneDShape(targetShape->begin(), targetShape->end() - 1);
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, oneDShape);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, oneDShape, loc, rewriter);
-      // Assume tdescTy, targetShape, and convertedIndiceVec are defined
-      int64_t outerDim = tdescTy.getShape().back();
-      int64_t innerDim = targetShape->back();
-      int64_t numInnerLoops = outerDim / innerDim;
+      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, shape1D);
+      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, shape1D, loc, rewriter);
 
-      // Get element size in bytes
-      int64_t elemSize = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
 
       for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
         for (int64_t i = 0; i < numInnerLoops; ++i) {
           // Compute the offset
-          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * innerDim);
+          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunk);
           Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
           Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
-          auto chunkSizeAttr = rewriter.getI64IntegerAttr(innerDim);
-          auto newOp = rewriter.create<xegpu::CreateDescOp>(
+          auto newOp = rewriter.create<xegpu::CreateDescOp>(
               loc, newTdescTy, op.getSource(), offsetIndice);
 
           newOps.push_back(newOp);
         }
       }
-    } else if (tdescTy.getRank() == 1) {
+    } else {
       convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
       convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
       for (auto indice : convertedIndiceVec) {
         auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
                                                           op.getSource(), indice);
         newOps.push_back(newOp);
       }
-    } else {
-      // Unsupported rank for tensor descriptor
-      return failure();
-    }
+    }
 
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
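
Note on the rank-2 branch above: each unrolled 1-D index piece is expanded into wholeChunk / blockedChunk sub-descriptors, each shifted by i * blockedChunk along the chunk dimension. A minimal standalone sketch of that arithmetic, with made-up sizes and index values (the real ones come from the descriptor type and the target tile):

// Sketch only: mirrors the loop structure above with plain integers
// instead of MLIR values. All concrete numbers are assumptions.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::int64_t wholeChunk = 8;    // assumed tdesc innermost (chunk) dim
  const std::int64_t blockedChunk = 2;  // assumed target tile innermost dim
  const std::int64_t numInnerLoops = wholeChunk / blockedChunk;  // 4 slices

  // One hypothetical unrolled 1-D index piece (a vector<4xindex> analogue).
  const std::vector<std::int64_t> indicePiece = {0, 64, 128, 192};

  for (std::int64_t i = 0; i < numInnerLoops; ++i) {
    // The pattern splats i * blockedChunk and adds it to the index vector.
    std::printf("sub-descriptor %lld indices:", static_cast<long long>(i));
    for (std::int64_t idx : indicePiece)
      std::printf(" %lld", static_cast<long long>(idx + i * blockedChunk));
    std::printf("\n");
  }
  return 0;
}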
@@ -493,10 +486,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (tdescTy.getRank() == 2) {
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      if (targetShape && targetShape->size() > 1) {
+        std::swap((*targetShape)[0], (*targetShape)[1]);
+        newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
 
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
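
The mask here has one lane per offset, so the rank-2 case unrolls it along the first dimension only and then reuses each 1-D piece once per chunk slice: every slice of the same rows is guarded by the same lanes. A small sketch of that replication order, assuming four unrolled mask pieces and numInnerLoops = 2:

// Sketch only: replicates each unrolled mask piece numInnerLoops times,
// matching the order in which the chunked sub-descriptors are produced.
#include <cstdint>
#include <string>
#include <vector>

int main() {
  const std::int64_t numInnerLoops = 2;  // assumed wholeChunk / blockedChunk
  const std::vector<std::string> masks1D = {"m0", "m1", "m2", "m3"};

  std::vector<std::string> masks;  // one entry per unrolled load op
  for (const std::string &mask : masks1D)
    for (std::int64_t i = 0; i < numInnerLoops; ++i)
      masks.push_back(mask);
  // masks is now m0, m0, m1, m1, m2, m2, m3, m3, lining up with the
  // descriptor pieces when the two lists are zipped together.
  return 0;
}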
@@ -505,9 +517,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
                               op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
-
+
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
@@ -521,7 +532,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -551,29 +562,48 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
-
-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (tdescTy.getRank() == 2) {
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
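
Here the value types are deliberately computed after the std::swap, so the stored value pieces are unrolled with the two target-shape dims exchanged, presumably because a chunked scatter's value operand is laid out chunk-major relative to the descriptor. A sketch of the swap's effect under an assumed [lanes, blockedChunk] target shape of [8, 2]:

// Sketch only: shows how the swap changes the shape used for unrolling
// the stored values. The [8, 2] target shape is an assumption.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<std::int64_t> targetShape = {8, 2};  // [lanes, blockedChunk]

  // Descriptors were unrolled with [8, 2]; values now use [2, 8].
  std::swap(targetShape[0], targetShape[1]);
  std::printf("value tile: %lld x %lld\n",
              static_cast<long long>(targetShape[0]),
              static_cast<long long>(targetShape[1]));
  return 0;
}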
@@ -597,7 +627,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
      return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -611,17 +641,36 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
 
     TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
     VectorType offsetVecTy = offsetVec.getType();
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetVecTy, *targetShape);
-    SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
+
+    if (tdescTy.getRank() == 2) {
+      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+      SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto offset : convertedOffsetVec1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedOffsetVec.push_back(offset);
+        }
+      }
+
+    } else {
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+      convertedOffsetVec = pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
           rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
       newOps.push_back(newOp);
     }
+
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
     return success();
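
As with the gather/scatter masks, offsets are unrolled along dim 0 only and each piece is repeated once per chunk slice, so the llvm::zip with the unrolled descriptors pairs every sub-descriptor with exactly one offset vector. A quick count check under assumed sizes:

// Sketch only: verifies the piece counts line up so llvm::zip drops
// nothing. All sizes are assumptions for illustration.
#include <cassert>
#include <cstdint>

int main() {
  const std::int64_t tdescDim0 = 16, tileDim0 = 8;      // offset dim, its tile
  const std::int64_t wholeChunk = 4, blockedChunk = 2;  // chunk dim, its tile

  const std::int64_t numOffsetPieces = tdescDim0 / tileDim0;    // 1-D pack
  const std::int64_t numInnerLoops = wholeChunk / blockedChunk; // repeats
  const std::int64_t numTdescPieces =
      (tdescDim0 / tileDim0) * (wholeChunk / blockedChunk);     // 2-D unroll

  // Each offset piece is repeated numInnerLoops times, so the pairing
  // with the unrolled descriptors is exact.
  assert(numOffsetPieces * numInnerLoops == numTdescPieces);
  return 0;
}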
