Skip to content

Commit a22b251

Browse files
committed
[mlir][XeGPU][VectorToXeGPU] Propagate vector layouts to xegpu ops
Signed-off-by: dchigarev <[email protected]>
1 parent 0c2e900 commit a22b251

File tree

3 files changed

+310
-20
lines changed

3 files changed

+310
-20
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 57 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -374,22 +374,29 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
374374
arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
375375
}
376376
Value indices = gatScatOp.getIndices();
377+
auto indicesLayout = mlir::xegpu::getDistributeLayoutAttr(indices);
377378
VectorType vecType = cast<VectorType>(indices.getType());
378379

379-
Value strideVector =
380-
vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
381-
.getResult();
382-
Value stridedIndices =
383-
arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
384-
385-
Value baseVector =
386-
vector::BroadcastOp::create(
387-
rewriter, loc,
388-
VectorType::get(vecType.getShape(), rewriter.getIndexType()),
389-
baseOffset)
390-
.getResult();
391-
return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
392-
.getResult();
380+
auto strideVector =
381+
vector::BroadcastOp::create(rewriter, loc, vecType, strides.back());
382+
mlir::xegpu::setDistributeLayoutAttr(strideVector->getOpResult(0),
383+
indicesLayout);
384+
385+
auto stridedIndices =
386+
arith::MulIOp::create(rewriter, loc, strideVector.getResult(), indices);
387+
mlir::xegpu::setDistributeLayoutAttr(stridedIndices->getOpResult(0),
388+
indicesLayout);
389+
390+
auto baseVector = vector::BroadcastOp::create(
391+
rewriter, loc,
392+
VectorType::get(vecType.getShape(), rewriter.getIndexType()), baseOffset);
393+
mlir::xegpu::setDistributeLayoutAttr(baseVector->getOpResult(0),
394+
indicesLayout);
395+
396+
auto result = arith::AddIOp::create(rewriter, loc, baseVector.getResult(),
397+
stridedIndices.getResult());
398+
mlir::xegpu::setDistributeLayoutAttr(result->getOpResult(0), indicesLayout);
399+
return result.getResult();
393400
}
394401

395402
template <
@@ -616,16 +623,34 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
616623
computeOffsets(rewriter, gatherOp, meta.first, meta.second);
617624
Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
618625

626+
auto numOffsets = gatherOp.getOffsets().size();
627+
auto layoutRes = mlir::xegpu::getDistributeLayoutAttr(gatherOp.getResult());
628+
auto layoutIndices = mlir::xegpu::getDistributeLayoutAttr(
629+
gatherOp->getOpOperand(numOffsets + 1));
630+
auto layoutMask = mlir::xegpu::getDistributeLayoutAttr(
631+
gatherOp->getOpOperand(numOffsets + 2));
632+
auto layoutPassThru = mlir::xegpu::getDistributeLayoutAttr(
633+
gatherOp->getOpOperand(numOffsets + 3));
619634
auto xeGatherOp = xegpu::LoadGatherOp::create(
620635
rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
621636
/*chunk_size=*/IntegerAttr{},
622637
/*l1_hint=*/xegpu::CachePolicyAttr{},
623638
/*l2_hint=*/xegpu::CachePolicyAttr{},
624639
/*l3_hint=*/xegpu::CachePolicyAttr{});
640+
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
641+
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(1),
642+
layoutIndices);
643+
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(2),
644+
layoutMask);
625645

626646
auto selectOp =
627647
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
628648
xeGatherOp.getResult(), gatherOp.getPassThru());
649+
mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(0), layoutMask);
650+
mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(2),
651+
layoutPassThru);
652+
mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpResult(0), layoutRes);
653+
629654
rewriter.replaceOp(gatherOp, selectOp.getResult());
630655
return success();
631656
}
@@ -650,12 +675,24 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
650675
computeOffsets(rewriter, scatterOp, meta.first, meta.second);
651676
Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
652677

653-
xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
654-
flatMemref, localOffsets, scatterOp.getMask(),
655-
/*chunk_size=*/IntegerAttr{},
656-
/*l1_hint=*/xegpu::CachePolicyAttr{},
657-
/*l2_hint=*/xegpu::CachePolicyAttr{},
658-
/*l3_hint=*/xegpu::CachePolicyAttr{});
678+
auto numOffsets = scatterOp.getOffsets().size();
679+
auto layoutIndices = mlir::xegpu::getDistributeLayoutAttr(
680+
scatterOp->getOpOperand(numOffsets + 1));
681+
auto layoutMask = mlir::xegpu::getDistributeLayoutAttr(
682+
scatterOp->getOpOperand(numOffsets + 2));
683+
auto layoutVal = mlir::xegpu::getDistributeLayoutAttr(
684+
scatterOp->getOpOperand(numOffsets + 3));
685+
auto storeOp = xegpu::StoreScatterOp::create(
686+
rewriter, loc, scatterOp.getValueToStore(), flatMemref, localOffsets,
687+
scatterOp.getMask(),
688+
/*chunk_size=*/IntegerAttr{},
689+
/*l1_hint=*/xegpu::CachePolicyAttr{},
690+
/*l2_hint=*/xegpu::CachePolicyAttr{},
691+
/*l3_hint=*/xegpu::CachePolicyAttr{});
692+
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(0), layoutVal);
693+
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(2),
694+
layoutIndices);
695+
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(3), layoutMask);
659696
rewriter.eraseOp(scatterOp);
660697
return success();
661698
}

mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir

Lines changed: 141 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -249,3 +249,144 @@ gpu.func @non_unit_inner_stride_3D(
249249
// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
250250
// CHECK: gpu.return %[[RES]] : vector<8xf32>
251251
}
252+
253+
// -----
254+
255+
gpu.module @xevm_module {
256+
gpu.func @load_dynamic_layout_operands(%source: memref<?x?xf32>,
257+
%off0: index, %off1: index,
258+
%indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
259+
%pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
260+
%res = vector.gather %source[%off0, %off1][%indices], %mask,
261+
%pass_thru {
262+
layout_result_0 = #xegpu.layout<sg_layout = [0]>,
263+
layout_operand_3 = #xegpu.layout<sg_layout = [1]>,
264+
layout_operand_4 = #xegpu.layout<sg_layout = [2]>,
265+
layout_operand_5 = #xegpu.layout<sg_layout = [3]>
266+
} : memref<?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
267+
gpu.return %res : vector<8x16xf32>
268+
}
269+
// CHECK-LABEL: @load_dynamic_layout_operands(
270+
// CHECK-SAME: %[[SRC:.+]]: memref<?x?xf32>,
271+
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
272+
// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>, %[[PASS:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
273+
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
274+
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
275+
// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
276+
// CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [1]>, layout_operand_2 = #xegpu.layout<sg_layout = [2]>,
277+
// CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [0]>}
278+
// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
279+
// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [2]>,
280+
// CHECK-SAME: {{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [3]>,
281+
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [0]>} : vector<8x16xi1>, vector<8x16xf32>
282+
}
283+
284+
// -----
285+
286+
gpu.module @xevm_module {
287+
gpu.func @load_dynamic_layout_mixed(%source: memref<?x?x?xf32>,
288+
%off0: index, %off1: index, %off2: index,
289+
%mask: vector<8x16xi1>) -> vector<8x16xf32> {
290+
%pass_thru = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
291+
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1]>} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
292+
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
293+
%0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout<sg_layout = [3]>} : vector<8x1xindex> to vector<8x16xindex>
294+
%1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout<sg_layout = [4]>} : vector<1x16xindex> to vector<8x16xindex>
295+
%2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
296+
297+
%res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
298+
%pass_thru {
299+
layout_result_0 = #xegpu.layout<sg_layout = [6]>,
300+
layout_operand_5 = #xegpu.layout<sg_layout = [7]>
301+
} : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
302+
%res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
303+
gpu.return %res2 : vector<8x16xf32>
304+
}
305+
// CHECK-LABEL: @load_dynamic_layout_mixed(
306+
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
307+
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
308+
// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
309+
// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
310+
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
311+
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
312+
// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
313+
// CHECK-SAME: {{{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [7]>
314+
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
315+
// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
316+
// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [7]>,
317+
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>} : vector<8x16xi1>, vector<8x16xf32>
318+
}
319+
320+
321+
// -----
322+
323+
gpu.module @xevm_module {
324+
gpu.func @load_static_layout_mixed(%source: memref<8x16x32xf32>,
325+
%off0: index, %off1: index, %off2: index,
326+
%mask: vector<8x16xi1>) -> vector<8x16xf32> {
327+
%pass_thru = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
328+
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1]>} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
329+
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
330+
%0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout<sg_layout = [3]>} : vector<8x1xindex> to vector<8x16xindex>
331+
%1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout<sg_layout = [4]>} : vector<1x16xindex> to vector<8x16xindex>
332+
%2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
333+
334+
%res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
335+
%pass_thru {
336+
layout_result_0 = #xegpu.layout<sg_layout = [6]>,
337+
layout_operand_5 = #xegpu.layout<sg_layout = [7]>
338+
} : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
339+
%res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
340+
gpu.return %res2 : vector<8x16xf32>
341+
}
342+
// CHECK-LABEL: @load_static_layout_mixed(
343+
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
344+
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
345+
// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
346+
// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
347+
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
348+
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
349+
// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
350+
// CHECK-SAME: {{{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [7]>
351+
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
352+
// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
353+
// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [7]>,
354+
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>} : vector<8x16xi1>, vector<8x16xf32>
355+
}
356+
357+
// -----
358+
359+
gpu.module @xevm_module {
360+
gpu.func @load_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
361+
%off0: index, %off1: index, %off2: index,
362+
%mask: vector<8x16xi1>) -> vector<8x16xf32> {
363+
%pass_thru = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
364+
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1]>} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
365+
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
366+
%0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout<sg_layout = [3]>} : vector<8x1xindex> to vector<8x16xindex>
367+
%1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout<sg_layout = [4]>} : vector<1x16xindex> to vector<8x16xindex>
368+
%2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
369+
370+
%res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
371+
%pass_thru {
372+
layout_result_0 = #xegpu.layout<sg_layout = [6]>,
373+
layout_operand_4 = #xegpu.layout<sg_layout = [99]>, // overriding %2's layout
374+
layout_operand_5 = #xegpu.layout<sg_layout = [7]>
375+
} : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
376+
%res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
377+
gpu.return %res2 : vector<8x16xf32>
378+
}
379+
// CHECK-LABEL: @load_dynamic_layout_mixed_override(
380+
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
381+
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
382+
// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
383+
// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
384+
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
385+
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
386+
// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
387+
// CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [99]>, layout_operand_2 = #xegpu.layout<sg_layout = [7]>
388+
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
389+
// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
390+
// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [7]>,
391+
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>} : vector<8x16xi1>, vector<8x16xf32>
392+
}

0 commit comments

Comments
 (0)