178 changes: 136 additions & 42 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();
SmallVector<int64_t> targetIndiceShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();
    // indiceVec has one fewer dimension than tdescTy when chunkSize is larger than 1.
if (originalChunkSize > 1)
targetIndiceShape.pop_back();

auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
SmallVector<Type> convertedIndiceTypes =
getUnrolledTypes(indiceVecTy, *targetShape);
getUnrolledTypes(indiceVecTy, targetIndiceShape);
SmallVector<Value> convertedIndiceVec =
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

SmallVector<Value> newOps;
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
op.getSource(), indice);
newOps.push_back(newOp);

    // More indices are needed when chunkSize > 1, since one large load from a
    // single address may be broken into multiple smaller loads.
if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
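      // Illustrative (hypothetical) numbers: originalChunkSize = 8 blocked to
      // targetShape->back() = 2 gives numNewChunks = 4, so each 1-D index
      // vector below spawns four descriptors at offsets 0, 2, 4, and 6.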

for (auto [indice, indiceType] :
llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
for (int64_t i = 0; i < numNewChunks; ++i) {
// Compute the offset
Value inc = rewriter.create<arith::ConstantIndexOp>(
loc, i * blockedChunkSize);
Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
Value offsetIndice =
rewriter.create<arith::AddIOp>(loc, indice, incVec);

auto newOp = rewriter.create<xegpu::CreateDescOp>(
loc, newTdescTy, op.getSource(), offsetIndice);

newOps.push_back(newOp);
}
}
} else {
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(
loc, newTdescTy, op.getSource(), indice);
newOps.push_back(newOp);
}
}

Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

SmallVector<int64_t> targetMaskShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

@@ -462,10 +492,32 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;

if (originalChunkSize > 1) {
targetMaskShape.pop_back();
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
SmallVector<Value> convertedMasks1D = pack(
op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
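      // Illustrative (hypothetical) numbers: chunkSize 8 blocked to 2 gives
      // numNewChunks = 4; each 1-D mask is reused for all four smaller loads
      // split from the same original chunk, since they cover the same lanes.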

for (auto mask : convertedMasks1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedMasks.push_back(mask);
}
}
// This is to handle the transpose effect when chunkSize > 1.
if (targetShape && targetShape->size() > 1) {
std::swap((*targetShape)[0], (*targetShape)[1]);
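        // e.g., (hypothetical) a [16, 2] target shape becomes [2, 16], since
        // a chunked gather yields its result with the chunk dimension first.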
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
}
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
loc, rewriter);
}

SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +528,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
}

Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

rewriter.replaceOp(op, castOp);
return success();
}
@@ -490,7 +541,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // check if the tensor descriptor type is at most 2D
if (tdescTy.getRank() > 1)
if (tdescTy.getRank() > 2)
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +570,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<int64_t> targetIndiceShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);

SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;

if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
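      // Illustrative (hypothetical) numbers: chunkSize 8 blocked to 2 gives
      // numNewChunks = 4, so each 1-D mask below is repeated four times, one
      // per smaller store split from the original chunk.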
convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
SmallVector<Value> convertedMasks1D = pack(
op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);

for (auto mask : convertedMasks1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedMasks.push_back(mask);
}
}
// This is to handle the transpose effect when chunkSize > 1.
std::swap((*targetShape)[0], (*targetShape)[1]);
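      // e.g., (hypothetical) [16, 2] becomes [2, 16]: the stored values are
      // chunk-major, mirroring the layout handling in UnrollLoadGatherOp.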

} else {
convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
}

SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
@@ -565,8 +637,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (tdescTy.getRank() > 2)
return failure();

if (!tdescTy.isScattered())
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -580,12 +654,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {

TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
VectorType offsetVecTy = offsetVec.getType();
SmallVector<Type> convertedOffsetTypes =
getUnrolledTypes(offsetVecTy, *targetShape);
SmallVector<Value> convertedOffsetVec =
pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);

SmallVector<Type> convertedOffsetTypes;
SmallVector<Value> convertedOffsetVec;
SmallVector<Value> newOps;
int64_t originalChunkSize = tdescTy.getChunkSize();
if (originalChunkSize > 1) {
SmallVector<int64_t> shape1D(targetShape->begin(),
targetShape->end() - 1);
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
SmallVector<Value> convertedOffsetVec1D =
pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);

int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
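      // Illustrative (hypothetical) numbers: chunkSize 8 blocked to 2 gives
      // numNewChunks = 4; each 1-D offset vector is reused for the four
      // descriptors created from the same original chunk.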

for (auto offset : convertedOffsetVec1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedOffsetVec.push_back(offset);
}
}

} else {
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
convertedOffsetVec =
pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
}

for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
auto newOp =
rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);