Skip to content

Commit 23f9557

Browse files
committed
Enhance Pack/Unpack for XeGPUUnroll
1 parent 27d8441 commit 23f9557

File tree

4 files changed

+75
-12
lines changed

4 files changed

+75
-12
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
145145
xegpu::DistributeLayoutAttr layout =
146146
xegpu::getDistributeLayoutAttr(operandOrResult);
147147
if (layout && layout.isForSubgroup()) {
148-
if (!layout.getEffectiveInstDataAsInt().empty())
149-
return layout.getEffectiveInstDataAsInt();
148+
if (!layout.getEffectiveInstDataAsInt().empty()) {
149+
SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
150+
// Remove all leading unit dimensions from inst_data
151+
while (!instData.empty() && instData.front() == 1)
152+
instData.erase(instData.begin());
153+
return instData;
154+
}
150155

151156
if (auto type = dyn_cast<ShapedType>(value.getType()))
152157
return llvm::to_vector(type.getShape());
@@ -363,7 +368,7 @@ void XeGPUBlockingPass::runOnOperation() {
363368
xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
364369
tdescTy.getLayoutAttr().dropInstData());
365370
} else {
366-
newTy = type.clone(tileShape, elemTy);
371+
newTy = VectorType::get(tileShape, elemTy);
367372
}
368373

369374
if (returnSingleType)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
6666
Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
6767
Location loc, PatternRewriter &rewriter) const {
6868
if (auto vecTy = dyn_cast<VectorType>(destTy)) {
69-
assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
70-
"Expecting blockSize size to match the rank of destTy.");
7169
auto shape = vecTy.getShape();
7270
return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
7371
}
@@ -93,8 +91,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
9391
ArrayRef<int64_t> blockSize, Location loc,
9492
PatternRewriter &rewriter) const {
9593
if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
96-
assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
97-
"Expecting blockSize size to match the rank of src.");
9894
return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
9995
blockSize);
10096
}
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
635631
VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
636632
VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
637633
Type elemTy = valueTy.getElementType();
638-
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
634+
VectorType newValueTy = VectorType::get(*targetShape, elemTy);
639635

640636
SmallVector<Type> convertedMaskTypes;
641637
SmallVector<Value> convertedMasks;

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,30 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
246246
if (!computeShapeRatio(srcShape, shape))
247247
return {value};
248248

249+
int64_t srcShapeRank = srcShape.size();
250+
int64_t targetShapeRank = shape.size();
251+
252+
SmallVector<int64_t> adjustedTargetShape(srcShape.size());
253+
int64_t rankDiff = srcShapeRank - targetShapeRank;
254+
std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
255+
1);
256+
std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
257+
258+
int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
259+
249260
SmallVector<Value> result;
250-
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
261+
for (SmallVector<int64_t> offsets :
262+
StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
251263
SmallVector<int64_t> staticStrides(offsets.size(), 1);
252-
result.push_back(vector::ExtractStridedSliceOp::create(
253-
builder, loc, value, offsets, shape, staticStrides));
264+
Value slice = vector::ExtractStridedSliceOp::create(
265+
builder, loc, value, offsets, adjustedTargetShape, staticStrides);
266+
267+
// Reshape to remove leading unit dims if needed
268+
if (adjustedTargetShapeRank > targetShapeRank) {
269+
auto targetTy = VectorType::get(shape, vecTy.getElementType());
270+
slice = builder.create<vector::ShapeCastOp>(loc, targetTy, slice);
271+
}
272+
result.push_back(slice);
254273
}
255274

256275
return result;
@@ -274,7 +293,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
274293

275294
for (auto [src, offsets] :
276295
llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
277-
SmallVector<int64_t> staticStrides(offsets.size(), 1);
296+
SmallVector<int64_t> staticStrides(tileShape.size(), 1);
278297
result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
279298
offsets, staticStrides);
280299
}

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,3 +682,46 @@ gpu.module @test_kernel {
682682
gpu.return
683683
}
684684
}
685+
686+
// -----
687+
gpu.module @test_kernel {
688+
// CHECK-LABEL: load_gather
689+
// CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
690+
gpu.func @load_gather(%src: ui64) -> vector<1x1x32xf32> {
691+
%cst = arith.constant dense<[[
692+
[0, 8, 16, 24, 32, 40, 48, 56,
693+
64, 72, 80, 88, 96, 104, 112, 120,
694+
128, 136, 144, 152, 160, 168, 176, 184,
695+
192, 200, 208, 216, 224, 232, 240, 248]
696+
]]> : vector<1x1x32xindex>
697+
698+
%mask = arith.constant dense<true> : vector<1x1x32xi1>
699+
%ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
700+
701+
gpu.return %ld : vector<1x1x32xf32>
702+
}
703+
}
704+
705+
// -----
706+
gpu.module @test_kernel {
707+
// CHECK-LABEL: store_scatter
708+
// CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
709+
gpu.func @store_scatter(%src: ui64) {
710+
%cst = arith.constant dense<[[
711+
[0, 8, 16, 24, 32, 40, 48, 56,
712+
64, 72, 80, 88, 96, 104, 112, 120,
713+
128, 136, 144, 152, 160, 168, 176, 184,
714+
192, 200, 208, 216, 224, 232, 240, 248]
715+
]]> : vector<1x1x32xindex>
716+
717+
%mask = arith.constant dense<true> : vector<1x1x32xi1>
718+
719+
%st_vec = arith.constant dense<1023.0>: vector<1x1x32xf32>
720+
xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [1, 1, 16]>,
721+
layout_operand_2 = #xegpu.layout<inst_data = [1, 1, 16]>,
722+
layout_operand_3 = #xegpu.layout<inst_data = [1, 1, 16]>,
723+
l1_hint = #xegpu.cache_hint<cached>} : vector<1x1x32xf32>, ui64, vector<1x1x32xindex>, vector<1x1x32xi1>
724+
725+
gpu.return
726+
}
727+
}

0 commit comments

Comments
 (0)