
Commit 2a43ee6

only set layouts for anchor ops

Signed-off-by: dchigarev <[email protected]>
1 parent 3afe5d5

3 files changed: 22 additions, 84 deletions

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 14 additions & 20 deletions
@@ -389,27 +389,22 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
         arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
   }
   Value indices = gatScatOp.getIndices();
-  // Extract indices layout and propagate it to all 'vector' ops created here
-  auto indicesLayout = xegpu::getDistributeLayoutAttr(indices);
   VectorType vecType = cast<VectorType>(indices.getType());
 
-  auto strideVector =
-      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back());
-  xegpu::setDistributeLayoutAttr(strideVector->getOpResult(0), indicesLayout);
-
-  auto stridedIndices =
-      arith::MulIOp::create(rewriter, loc, strideVector.getResult(), indices);
-  xegpu::setDistributeLayoutAttr(stridedIndices->getOpResult(0), indicesLayout);
-
-  auto baseVector = vector::BroadcastOp::create(
-      rewriter, loc,
-      VectorType::get(vecType.getShape(), rewriter.getIndexType()), baseOffset);
-  xegpu::setDistributeLayoutAttr(baseVector->getOpResult(0), indicesLayout);
-
-  auto result = arith::AddIOp::create(rewriter, loc, baseVector.getResult(),
-                                      stridedIndices.getResult());
-  xegpu::setDistributeLayoutAttr(result->getOpResult(0), indicesLayout);
-  return result.getResult();
+  Value strideVector =
+      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
+          .getResult();
+  Value stridedIndices =
+      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
+
+  Value baseVector =
+      vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
+          baseOffset)
+          .getResult();
+  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
+      .getResult();
 }
 
 template <
@@ -659,7 +654,6 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
         arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                 xeGatherOp.getResult(), gatherOp.getPassThru());
     xegpu::setDistributeLayoutAttr(selectOp.getConditionMutable(), layoutMask);
-    xegpu::setDistributeLayoutAttr(selectOp.getTrueValueMutable(), layoutRes);
     xegpu::setDistributeLayoutAttr(selectOp.getFalseValueMutable(),
                                    layoutPassThru);
     xegpu::setDistributeLayoutAttr(selectOp->getOpResult(0), layoutRes);
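
Read outside diff form, the change is: the broadcast/muli/addi chain that linearizes the gather/scatter indices is now built without any layout attributes, and layouts are attached only to anchor ops (the xegpu.load/xegpu.store and the final arith.select), leaving intermediate ops to layout propagation. A minimal sketch of the new index computation, assuming the MLIR builder APIs used in the patch; the helper name buildLinearIndices and the include set are illustrative, not part of the commit:

```cpp
// Illustrative sketch (hypothetical helper): compute base + stride * indices
// as plain vector/arith ops, with no xegpu::setDistributeLayoutAttr calls on
// the intermediate results.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static Value buildLinearIndices(PatternRewriter &rewriter, Location loc,
                                Value indices, Value stride, Value base) {
  auto vecType = cast<VectorType>(indices.getType());
  // Splat the innermost stride and scale the per-element indices by it.
  Value strideVec =
      vector::BroadcastOp::create(rewriter, loc, vecType, stride).getResult();
  Value scaled =
      arith::MulIOp::create(rewriter, loc, strideVec, indices).getResult();
  // Splat the scalar base offset to the same vector shape and add it in.
  Value baseVec =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()), base)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVec, scaled).getResult();
}
```

Only the anchor op consuming this value is then tagged, e.g. xegpu::setDistributeLayoutAttr(anchorOp->getOpResult(0), layoutRes). Dropping the annotation on the select's true value appears consistent with this: that operand is the xegpu.load result, whose layout is already recorded as layout_result_0 on the load itself.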

mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir

Lines changed: 4 additions & 32 deletions
@@ -268,13 +268,7 @@ gpu.func @load_dynamic_layout_operands(%source: memref<?x?xf32>,
   gpu.return %res : vector<8x16xf32>
 }
 // CHECK-LABEL: @load_dynamic_layout_operands(
-// CHECK-SAME:  %[[SRC:.+]]: memref<?x?xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
-// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>, %[[PASS:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
-// %indices producer doesn't have a layout, so as 'broadcast/add' ops computing linear index.
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
+// CHECK: %[[VEC:.+]] = xegpu.load {{[^{]*}}
 // CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [1]>, layout_operand_2 = #xegpu.layout<sg_layout = [2]>,
 // CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [0]>}
 // CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
@@ -305,14 +299,7 @@ gpu.func @load_dynamic_layout_mixed(%source: memref<?x?x?xf32>,
   gpu.return %res2 : vector<8x16xf32>
 }
 // CHECK-LABEL: @load_dynamic_layout_mixed(
-// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
-// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
-// Verify that linear-indices computation uses layout from the 'indices' producer op (%2).
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
+// CHECK: %[[VEC:.+]] = xegpu.load {{[^{]*}}
 // CHECK-SAME: {{{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [7]>
 // CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
 // CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
@@ -343,14 +330,7 @@ gpu.func @load_static_layout_mixed(%source: memref<8x16x32xf32>,
   gpu.return %res2 : vector<8x16xf32>
 }
 // CHECK-LABEL: @load_static_layout_mixed(
-// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
-// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
-// Verify that linear-indices computation uses layout from the 'indices' producer op (%2).
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
+// CHECK: %[[VEC:.+]] = xegpu.load {{[^{]*}}
 // CHECK-SAME: {{{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [7]>
 // CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
 // CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
@@ -381,15 +361,7 @@ gpu.func @load_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
   gpu.return %res2 : vector<8x16xf32>
 }
 // CHECK-LABEL: @load_dynamic_layout_mixed_override(
-// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
-// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
-// Verify that linear-indices computation uses layout from the 'indices' producer op (%2)
-// and not it's overriden version from the scatter_op (sg_layout = [99])
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
+// CHECK: %[[VEC:.+]] = xegpu.load {{[^{]*}}
 // CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [99]>, layout_operand_2 = #xegpu.layout<sg_layout = [7]>
 // CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
 // CHECK: %[[RES:.+]] = arith.select {{[^{]*}}

mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir

Lines changed: 4 additions & 32 deletions
@@ -219,13 +219,7 @@ gpu.func @store_dynamic_layout_operands(%vec: vector<8x16xf32>, %source: memref<
   gpu.return
 }
 // CHECK-LABEL: @store_dynamic_layout_operands(
-// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x?xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
-// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
-// %indices producer doesn't have a layout, so as 'broadcast/add' ops computing linear index.
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
-// CHECK: xegpu.store %[[VEC]], %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
+// CHECK: xegpu.store {{[^{]*}}
 // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [2]>, layout_operand_2 = #xegpu.layout<sg_layout = [0]>, layout_operand_3 = #xegpu.layout<sg_layout = [1]>}
 }
 
@@ -248,14 +242,7 @@ gpu.func @store_dynamic_layout_mixed(%source: memref<?x?x?xf32>,
   gpu.return
 }
 // CHECK-LABEL: @store_dynamic_layout_mixed(
-// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>) {
-// CHECK: %[[VEC:.+]] = arith.constant {layout_operand_0 = #xegpu.layout<sg_layout = [0]>} dense<1.000000e+00> : vector<8x16xf32>
-// Verify that linear-indices computation uses layout from the 'indices' producer op (%2).
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
-// CHECK: xegpu.store %[[VEC]], %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
+// CHECK: xegpu.store {{[^{]*}}
 // CHECK-SAME: {{[^}]*}}layout_operand_3 = #xegpu.layout<sg_layout = [6]>}
 }
 
@@ -278,14 +265,7 @@ gpu.func @store_static_layout_mixed(%source: memref<8x16x32xf32>,
   gpu.return
 }
 // CHECK-LABEL: @store_static_layout_mixed(
-// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>) {
-// CHECK: %[[VEC:.+]] = arith.constant {layout_operand_0 = #xegpu.layout<sg_layout = [0]>} dense<1.000000e+00> : vector<8x16xf32>
-// Verify that linear-indices computation uses layout from the 'indices' producer op (%2).
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
-// CHECK: xegpu.store %[[VEC]], %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
+// CHECK: xegpu.store {{[^{]*}}
 // CHECK-SAME: {{[^}]*}}layout_operand_3 = #xegpu.layout<sg_layout = [6]>}
 }
 
@@ -309,15 +289,7 @@ gpu.func @store_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
   gpu.return
 }
 // CHECK-LABEL: @store_dynamic_layout_mixed_override(
-// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>) {
-// CHECK: %[[VEC:.+]] = arith.constant {layout_operand_0 = #xegpu.layout<sg_layout = [0]>} dense<1.000000e+00> : vector<8x16xf32>
-// Verify that linear-indices computation uses layout from the 'indices' producer op (%2)
-// and not it's overriden version from the scatter_op (sg_layout = [99])
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
-// CHECK: xegpu.store %[[VEC]], %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
+// CHECK: xegpu.store {{[^{]*}}
 // CHECK-SAME: {{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [99]>,
 // CHECK-SAME: {{[^}]*}}layout_operand_3 = #xegpu.layout<sg_layout = [6]>}
 }
