Skip to content

Commit e6a814e

Browse files
committed
Address comments
1 parent a4ebc37 commit e6a814e

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,17 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
148148
if (!layout.getEffectiveInstDataAsInt().empty()) {
149149
SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
150150
// Remove leading unit dimensions from inst_data
151+
// For example, if inst_data is [1, 1, 32],
152+
// [32] is passed as the unroll/blocking size.
151153
// Skip the removal for xegpu nd ops since inst_data will be 2D
152154
Operation *definingOp = value.getDefiningOp();
153155
bool skipLeadingUnitDimRemoval =
154156
definingOp &&
155157
(isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
156158
xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
157159
if (!skipLeadingUnitDimRemoval) {
158-
while (!instData.empty() && instData.front() == 1)
159-
instData.erase(instData.begin());
160+
auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
161+
instData.erase(instData.begin(), it);
160162
}
161163
return instData;
162164
}

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,6 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
255255
1);
256256
std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
257257

258-
int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
259-
260258
SmallVector<Value> result;
261259
for (SmallVector<int64_t> offsets :
262260
StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
@@ -265,7 +263,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
265263
builder, loc, value, offsets, adjustedTargetShape, staticStrides);
266264

267265
// Reshape to remove leading unit dims if needed
268-
if (adjustedTargetShapeRank > targetShapeRank) {
266+
if (srcShapeRank > targetShapeRank) {
269267
auto targetTy = VectorType::get(shape, vecTy.getElementType());
270268
slice = builder.create<vector::ShapeCastOp>(loc, targetTy, slice);
271269
}

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ gpu.module @test_kernel {
685685

686686
// -----
687687
gpu.module @test_kernel {
688-
// CHECK-LABEL: load_gather
688+
// CHECK-LABEL: remove_unit_dim_inst_data
689689
// CHECK-SAME: [[arg0:%.+]]: ui64
690690
// CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x1x32xf32>
691691
// CHECK: [[cst_0:%.+]] = arith.constant dense<true> : vector<16xi1>
@@ -695,7 +695,7 @@ gpu.module @test_kernel {
695695
// CHECK: [[ld_1:%.+]] = xegpu.load [[arg0]][[[cst_2]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
696696
// CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
697697
// CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
698-
gpu.func @load_gather(%src: ui64) -> vector<1x1x32xf32> {
698+
gpu.func @remove_unit_dim_inst_data(%src: ui64) -> vector<1x1x32xf32> {
699699
%cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
700700
[0, 8, 16, 24, 32, 40, 48, 56,
701701
64, 72, 80, 88, 96, 104, 112, 120,

0 commit comments

Comments
 (0)