Commit 56c3c6c

nbpatel authored and aokblast committed
[MLIR][XeGPU] Enhance Pack/Unpack for XeGPUUnroll (llvm#163459)
This PR changes the pack/unpack methods used for unrolling so that lower-rank slices can be extracted from, and inserted into, the source vector by adding reshapes. It also removes leading unit dims from inst_data, if any are present.
1 parent d003f15 commit 56c3c6c

File tree

4 files changed: +113, -13 lines

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 21 additions & 4 deletions
@@ -145,8 +145,26 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   xegpu::DistributeLayoutAttr layout =
       xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (!layout.getEffectiveInstDataAsInt().empty())
-      return layout.getEffectiveInstDataAsInt();
+    if (!layout.getEffectiveInstDataAsInt().empty()) {
+      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+      // Remove leading unit dimensions from inst_data
+      // For example, if the inst_data is [1, 1, 32]
+      // it will pass [32] as the unroll/blocking size.
+      // Skip it for xegpu nd ops since it will be 2D
+      // TODO: For vectors ops, experiment with the
+      // upstream vector remove leading unit dims patterns,
+      // populateCastAwayVectorLeadingOneDimPatterns.
+      Operation *definingOp = value.getDefiningOp();
+      bool skipLeadingUnitDimRemoval =
+          definingOp &&
+          (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
+               xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
+      if (!skipLeadingUnitDimRemoval) {
+        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
+        instData.erase(instData.begin(), it);
+      }
+      return instData;
+    }

     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());

@@ -354,7 +372,6 @@ void XeGPUBlockingPass::runOnOperation() {
             // To create a new attribute with a different chunk_size:
             auto newEncoding = xegpu::ScatterTensorDescAttr::get(
                 ctx, tdescTy.getMemorySpace(), blockedChunkSize);
-
             encoding = newEncoding;
           }
         }

@@ -363,7 +380,7 @@ void XeGPUBlockingPass::runOnOperation() {
           xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                      tdescTy.getLayoutAttr().dropInstData());
     } else {
-      newTy = type.clone(tileShape, elemTy);
+      newTy = VectorType::get(tileShape, elemTy);
     }

     if (returnSingleType)
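The getTileShape change reduces to a find-and-erase over the leading 1s of inst_data. A minimal standalone sketch of the same idea, using std::vector in place of SmallVector (dropLeadingUnitDims is an illustrative name, not in-tree API):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Drop the leading run of unit dims: [1, 1, 32] -> [32].
// Edge case: an all-ones shape trims down to empty, so callers that need
// at least one dim must guard for that.
std::vector<int64_t> dropLeadingUnitDims(std::vector<int64_t> shape) {
  auto it = std::find_if(shape.begin(), shape.end(),
                         [](int64_t d) { return d != 1; });
  shape.erase(shape.begin(), it);
  return shape;
}

int main() {
  assert(dropLeadingUnitDims({1, 1, 32}) == std::vector<int64_t>({32}));
  assert(dropLeadingUnitDims({1, 16}) == std::vector<int64_t>({16}));
  // Only the leading run is removed; interior unit dims survive.
  assert(dropLeadingUnitDims({8, 1, 16}) == std::vector<int64_t>({8, 1, 16}));
}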

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 1 addition & 5 deletions
@@ -66,8 +66,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
                Location loc, PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(destTy)) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
       return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }

@@ -93,8 +91,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
                             ArrayRef<int64_t> blockSize, Location loc,
                             PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of src.");
       return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                      blockSize);
     }

@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
     VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
     Type elemTy = valueTy.getElementType();
-    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    VectorType newValueTy = VectorType::get(*targetShape, elemTy);

     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
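The removed asserts pinned blockSize to the exact rank of the vector being packed or unpacked. After this change a lower-rank blockSize is legal: it is right-aligned against the source shape, with the missing leading dims treated as 1. A hedged sketch of that contract (tileCounts is an illustrative helper, not the in-tree computeShapeRatio):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Right-align `block` against `src` and return per-dimension tile counts,
// or nullopt if the tiling does not divide evenly. Missing leading dims of
// `block` act as 1, so a lower-rank block is accepted.
std::optional<std::vector<int64_t>>
tileCounts(const std::vector<int64_t> &src,
           const std::vector<int64_t> &block) {
  if (block.size() > src.size())
    return std::nullopt;
  size_t rankDiff = src.size() - block.size();
  std::vector<int64_t> counts(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    int64_t b = i < rankDiff ? 1 : block[i - rankDiff];
    if (b <= 0 || src[i] % b != 0)
      return std::nullopt;
    counts[i] = src[i] / b;
  }
  return counts;
}

int main() {
  // vector<1x1x32xf32> unrolled with blockSize [16]: two tiles on the last dim.
  assert(tileCounts({1, 1, 32}, {16}) == std::vector<int64_t>({1, 1, 2}));
}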

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 21 additions & 4 deletions
@@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   if (!computeShapeRatio(srcShape, shape))
     return {value};

+  int64_t srcShapeRank = srcShape.size();
+  int64_t targetShapeRank = shape.size();
+
+  SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+  int64_t rankDiff = srcShapeRank - targetShapeRank;
+  std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+            1);
+  std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
   SmallVector<Value> result;
-  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
-    result.push_back(vector::ExtractStridedSliceOp::create(
-        builder, loc, value, offsets, shape, staticStrides));
+    Value slice = vector::ExtractStridedSliceOp::create(
+        builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+    // Reshape to remove leading unit dims if needed
+    if (srcShapeRank > targetShapeRank) {
+      auto targetTy = VectorType::get(shape, vecTy.getElementType());
+      slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
+    }
+    result.push_back(slice);
   }

   return result;

@@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,

   for (auto [src, offsets] :
        llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
-    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    SmallVector<int64_t> staticStrides(tileShape.size(), 1);
     result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
                                                   offsets, staticStrides);
   }

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 70 additions & 0 deletions
@@ -682,3 +682,73 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: remove_unit_dim_inst_data
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x1x32xf32>
+  // CHECK: [[cst_0:%.+]] = arith.constant dense<true> : vector<16xi1>
+  // CHECK: [[cst_1:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK: [[cst_2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[ld_0:%.+]] = xegpu.load [[arg0]][[[cst_1]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  // CHECK: [[ld_1:%.+]] = xegpu.load [[arg0]][[[cst_2]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  // CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
+  // CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
+  gpu.func @remove_unit_dim_inst_data(%src: ui64) -> vector<1x1x32xf32> {
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
+      [0, 8, 16, 24, 32, 40, 48, 56,
+       64, 72, 80, 88, 96, 104, 112, 120,
+       128, 136, 144, 152, 160, 168, 176, 184,
+       192, 200, 208, 216, 224, 232, 240, 248]
+    ]]> : vector<1x1x32xindex>
+
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<true> : vector<1x1x32xi1>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+
+    gpu.return %ld : vector<1x1x32xf32>
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [1, 16]>
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_store_nd_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: memref<1024x1024xf32>, [[arg1:%.+]]: memref<1024x1024xf32>, [[arg2:%.+]]: memref<1024x1024xf32>
+  // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+  // CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
+  // CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
+  // CHECK: [[tdesc_a:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[tdesc_b:%.+]] = xegpu.create_nd_tdesc [[arg1]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[tdesc_c:%.+]] = xegpu.create_nd_tdesc [[arg2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[ld_a0:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c0]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_a1:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c16]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_b0:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c0]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_b1:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c16]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[cast_a0:%.+]] = vector.shape_cast [[ld_a0]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[cast_b0:%.+]] = vector.shape_cast [[ld_b0]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[add0:%.+]] = arith.addf [[cast_a0]], [[cast_b0]] : vector<16xf32>
+  // CHECK: [[ins0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x32xf32>
+  // CHECK: [[cast_a1:%.+]] = vector.shape_cast [[ld_a1]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[cast_b1:%.+]] = vector.shape_cast [[ld_b1]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[add1:%.+]] = arith.addf [[cast_a1]], [[cast_b1]] : vector<16xf32>
+  // CHECK: [[ins1:%.+]] = vector.insert_strided_slice [[add1]], [[ins0]] {offsets = [0, 16], strides = [1]} : vector<16xf32> into vector<1x32xf32>
+  // CHECK: [[ext0:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+  // CHECK: [[ext1:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+  // CHECK: xegpu.store_nd [[ext0]], [[tdesc_c]][[[c0]], [[c0]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  // CHECK: xegpu.store_nd [[ext1]], [[tdesc_c]][[[c0]], [[c16]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  gpu.func @load_store_nd_with_offsets(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+
+    %a_tdesc = xegpu.create_nd_tdesc %A : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+    %b_tdesc = xegpu.create_nd_tdesc %B : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+    %c_tdesc = xegpu.create_nd_tdesc %C : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+
+    %a = xegpu.load_nd %a_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
+    %b = xegpu.load_nd %b_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
+
+    %result = arith.addf %a, %b {layout_result_0 = #l} : vector<1x32xf32>
+    xegpu.store_nd %result, %c_tdesc[%c0, %c0] : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #l>
+    gpu.return
+  }
+}
