107 changes: 84 additions & 23 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,20 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}

// Extract cache hints from the op attributes if available.
static void getOpCacheHints(Operation *op,
SmallVector<xegpu::CachePolicyAttr, 3> &hints) {
assert(hints.size() == 3 &&
"Expecting a vector of size 3 for l1, l2, l3 hints.");
// get l1, l2, l3 hints from attributes if available.
if (auto l1Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l1_hint"))
hints[0] = l1Attr;
if (auto l2Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l2_hint"))
hints[1] = l2Attr;
if (auto l3Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l3_hint"))
hints[2] = l3Attr;
}

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -374,22 +388,30 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
Value indices = gatScatOp.getIndices();
// Extract indices layout and propagate it to all 'vector' ops created here
auto indicesLayout = mlir::xegpu::getDistributeLayoutAttr(indices);
VectorType vecType = cast<VectorType>(indices.getType());
dchigarev (Contributor, Author) commented:
In this section we compute 'flat' indices for the xegpu.load/store op from the vector op's %indices operand, like this:

%base_offset = vector.broadcast index -> vector<...>
%last_stride = vector.broadcast index -> vector<...>
%flat_indices = %indices * %last_stride + %base_offset

We want all the vector broadcast/add/mul operations generated here to have layout_result_0 equal to the %indices layout.
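For concreteness, the emitted IR then looks roughly like this once the layouts are attached (the SSA names, shapes, and sg_layout value are illustrative only, mirroring the tests added below):

%stride_splat = vector.broadcast %last_stride {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
%scaled = arith.muli %stride_splat, %indices {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
%base_splat = vector.broadcast %base_offset {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
%flat_indices = arith.addi %base_splat, %scaled {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>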

Contributor commented:
We should use the propagation pass to properly set up the layout attributes for non-anchor ops.

dchigarev (Contributor, Author) replied:
I'm wondering whether I understand the idea of the -xegpu-propagate-layout pass correctly.

Based on your comment and on my initial impression (and on this docstring), the pass should propagate user's layouts or set a default one, for example:

// if I set a custom layout for `xegpu.store`, I would expect the same layout
// to propagate to the result of `arith.select` (producer of the operand_0)
%res = arith.select %mask, %other, %3 : vector<16x16xi1>, vector<16x16xf16>
xegpu.store %res, %src[%offset], %1 {
    layout_operand_0 = #xegpu.layout<lane_layout = [4, 4], lane_data = [1, 2]>}
    : vector<16x16xf16>, memref<256xf16>, vector<16x16xindex>, vector<16x16xi1>

// however in reality it applies "defaultSIMTLayout" ignoring my custom layout:
%res = arith.select %mask, %other, %3 {
    layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
    : vector<16x16xi1>, vector<16x16xf16>
xegpu.store %res, %src[%offset], %1  {
    layout_operand_0 = #xegpu.layout<lane_layout = [4, 4], lane_data = [1, 2]>} 
    : vector<16x16xf16>, memref<256xf16>, vector<16x16xindex>, vector<16x16xi1>

Based on the logic in the pass, it always applies the default SIMT layout to the results/producers of store_scatter (and probably of load_gather as well). Is this the intended behavior, or by "we should use the propagation pass" did you mean that we should improve the pass to also consider custom user layouts?

cc @charithaintc

charithaintc (Contributor) commented on Oct 23, 2025:
Hi @dchigarev, the pass is still WIP. Ideally it should respect user-provided custom layouts, but that requires handling conflicts when a user assigns some arbitrary layout that the HW cannot support. Because the conflict-handling part is not implemented yet, we decided to ignore users' layouts for now. Did I answer your question?

dchigarev (Contributor, Author) replied:
> Did I answer your question?

Yes, thank you. I'm just making sure we're aware that, after merging just this PR, we won't fully cover vector-op layout propagation, since arith.broadcast/mul/etc. would get the default layouts rather than the user-set layouts from the lowered op (so the fuser team won't be satisfied).

Does it make sense to keep this temporary logic that manually sets vector-op layouts on all the ops generated by the vector-to-xegpu pass (without relying on the propagate-layout pass), or should we drop it and start working on the changes to the propagate-layout pass right away?

Contributor replied:
I am not 100% up to date on the requirement here, but as a rule of thumb I would just focus on setting the layouts properly for XeGPU operations that have layout information in their definition (loads/stores/createNd, etc.). These layouts will not be lost during transformations.

For all other arith ops that are generated as part of this lowering, we can safely ignore the layouts, because layout propagation should be able to take care of them (no need to replicate that logic). If a certain op requires a non-default layout for some reason, then it should carry that information, but I highly doubt this will be the case.
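For illustration only (not part of this diff), a minimal sketch of that approach in GatherLowering, reusing the names from the code below and assuming layoutRes was queried from the vector.gather result:

auto xeGatherOp = xegpu::LoadGatherOp::create(
    rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
    /*chunk_size=*/IntegerAttr{}, cacheHints[0], cacheHints[1], cacheHints[2]);
// The anchor op keeps its layout; it survives later transformations.
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
// The helper broadcast/muli/addi/select ops get no explicit layouts here;
// -xegpu-propagate-layout is expected to fill them in.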


Value strideVector =
vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
.getResult();
Value stridedIndices =
arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

Value baseVector =
vector::BroadcastOp::create(
rewriter, loc,
VectorType::get(vecType.getShape(), rewriter.getIndexType()),
baseOffset)
.getResult();
return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
.getResult();
auto strideVector =
vector::BroadcastOp::create(rewriter, loc, vecType, strides.back());
mlir::xegpu::setDistributeLayoutAttr(strideVector->getOpResult(0),
indicesLayout);

auto stridedIndices =
arith::MulIOp::create(rewriter, loc, strideVector.getResult(), indices);
mlir::xegpu::setDistributeLayoutAttr(stridedIndices->getOpResult(0),
indicesLayout);

auto baseVector = vector::BroadcastOp::create(
rewriter, loc,
VectorType::get(vecType.getShape(), rewriter.getIndexType()), baseOffset);
mlir::xegpu::setDistributeLayoutAttr(baseVector->getOpResult(0),
indicesLayout);

auto result = arith::AddIOp::create(rewriter, loc, baseVector.getResult(),
stridedIndices.getResult());
mlir::xegpu::setDistributeLayoutAttr(result->getOpResult(0), indicesLayout);
return result.getResult();
}

template <
@@ -616,16 +638,39 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
computeOffsets(rewriter, gatherOp, meta.first, meta.second);
Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

auto numOffsets = gatherOp.getOffsets().size();
auto layoutRes = mlir::xegpu::getDistributeLayoutAttr(gatherOp.getResult());
auto layoutIndices = mlir::xegpu::getDistributeLayoutAttr(
gatherOp->getOpOperand(numOffsets + 1));
auto layoutMask = mlir::xegpu::getDistributeLayoutAttr(
gatherOp->getOpOperand(numOffsets + 2));
auto layoutPassThru = mlir::xegpu::getDistributeLayoutAttr(
gatherOp->getOpOperand(numOffsets + 3));

SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
xegpu::CachePolicyAttr{},
xegpu::CachePolicyAttr{}};
getOpCacheHints(gatherOp, cacheHints);
auto xeGatherOp = xegpu::LoadGatherOp::create(
rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
/*l1_hint=*/cacheHints[0],
/*l2_hint=*/cacheHints[1],
/*l3_hint=*/cacheHints[2]);
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(1),
layoutIndices);
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(2),
layoutMask);

auto selectOp =
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
xeGatherOp.getResult(), gatherOp.getPassThru());
Contributor commented:
Just to double-check: the layout isn't assigned to the second operand (of LoadGather) because it's already on the producer's result.
I assume it's left to propagation to fill the gap? Any drawbacks to assigning the layout in both places?

dchigarev (Contributor, Author) replied:
> Just to double-check, the layout isn't assigned to the second operand (LoadGather) as it's already in the producer's result

Right, I thought that was enough. There seem to be no drawbacks to assigning the layout in both places, though, so I applied it in both places in the last commit just in case.

Contributor replied:
Setting the operand layout is done by the client of the layout-propagation result, so I think there is no need to update layout_operand_* here.

dchigarev (Contributor, Author) replied:
Is that always the case? Say the value to store comes in as a function argument (our func is not inlined yet): it's impossible to determine the producer, but it is possible to determine the layout from the {layout_operand_i} attribute. Should we support such cases? Are there any negative side effects from setting layout_operand_i even if we can access the layout from layout_result_i of the value's producer?

Jianhui-Li (Contributor) commented on Oct 29, 2025:
I think we should only set the layout for the anchor op and leave it to propagation to set the layouts for the operands of neighbouring ops.
In case the values/operands come from function arguments, the propagation should propagate through functions. In practice, the use cases we have are mostly inlined, so we didn't run into any function calls within a device kernel.

And we don't need to have these temporary layouts in the tests.

mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(0), layoutMask);
mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(2),
layoutPassThru);
mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpResult(0), layoutRes);

rewriter.replaceOp(gatherOp, selectOp.getResult());
return success();
}
@@ -650,12 +695,28 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
computeOffsets(rewriter, scatterOp, meta.first, meta.second);
Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
flatMemref, localOffsets, scatterOp.getMask(),
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
auto numOffsets = scatterOp.getOffsets().size();
auto layoutIndices = mlir::xegpu::getDistributeLayoutAttr(
scatterOp->getOpOperand(numOffsets + 1));
auto layoutMask = mlir::xegpu::getDistributeLayoutAttr(
scatterOp->getOpOperand(numOffsets + 2));
auto layoutVal = mlir::xegpu::getDistributeLayoutAttr(
scatterOp->getOpOperand(numOffsets + 3));
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
xegpu::CachePolicyAttr{},
xegpu::CachePolicyAttr{}};
getOpCacheHints(scatterOp, cacheHints);
auto storeOp = xegpu::StoreScatterOp::create(
rewriter, loc, scatterOp.getValueToStore(), flatMemref, localOffsets,
scatterOp.getMask(),
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/cacheHints[0],
/*l2_hint=*/cacheHints[1],
/*l3_hint=*/cacheHints[2]);
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(0), layoutVal);
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(2),
layoutIndices);
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(3), layoutMask);
rewriter.eraseOp(scatterOp);
return success();
}
185 changes: 185 additions & 0 deletions mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -249,3 +249,188 @@ gpu.func @non_unit_inner_stride_3D(
// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
// CHECK: gpu.return %[[RES]] : vector<8xf32>
}

// -----

gpu.module @xevm_module {
// Layouts are only specified for the gather op itself.
gpu.func @load_dynamic_layout_operands(%source: memref<?x?xf32>,
%off0: index, %off1: index,
%indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
%pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
%res = vector.gather %source[%off0, %off1][%indices], %mask,
%pass_thru {
layout_result_0 = #xegpu.layout<sg_layout = [0]>,
layout_operand_3 = #xegpu.layout<sg_layout = [1]>,
layout_operand_4 = #xegpu.layout<sg_layout = [2]>,
layout_operand_5 = #xegpu.layout<sg_layout = [3]>
} : memref<?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
gpu.return %res : vector<8x16xf32>
}
// CHECK-LABEL: @load_dynamic_layout_operands(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?xf32>,
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>, %[[PASS:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
// The %indices producer doesn't have a layout, and neither do the 'broadcast/add' ops computing the linear index.
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
// CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [1]>, layout_operand_2 = #xegpu.layout<sg_layout = [2]>,
// CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [0]>}
// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [2]>,
// CHECK-SAME: {{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [3]>,
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [0]>} : vector<8x16xi1>, vector<8x16xf32>
}

// -----

gpu.module @xevm_module {
gpu.func @load_dynamic_layout_mixed(%source: memref<?x?x?xf32>,
%off0: index, %off1: index, %off2: index,
%mask: vector<8x16xi1>) -> vector<8x16xf32> {
%pass_thru = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1]>} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
%0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout<sg_layout = [3]>} : vector<8x1xindex> to vector<8x16xindex>
%1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout<sg_layout = [4]>} : vector<1x16xindex> to vector<8x16xindex>
%2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>

%res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
%pass_thru {
layout_result_0 = #xegpu.layout<sg_layout = [6]>,
layout_operand_5 = #xegpu.layout<sg_layout = [7]>
} : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
%res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
gpu.return %res2 : vector<8x16xf32>
}
// CHECK-LABEL: @load_dynamic_layout_mixed(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
// Verify that linear-indices computation uses layout from the 'indices' producer op (%2).
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
// CHECK-SAME: {{{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [7]>
Contributor commented:
Similar to my previous comment, operand layouts are not needed.

Contributor replied:
+1

// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [7]>,
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>} : vector<8x16xi1>, vector<8x16xf32>
}


// -----

gpu.module @xevm_module {
gpu.func @load_static_layout_mixed(%source: memref<8x16x32xf32>,
%off0: index, %off1: index, %off2: index,
%mask: vector<8x16xi1>) -> vector<8x16xf32> {
%pass_thru = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1]>} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
%0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout<sg_layout = [3]>} : vector<8x1xindex> to vector<8x16xindex>
%1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout<sg_layout = [4]>} : vector<1x16xindex> to vector<8x16xindex>
%2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>

%res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
%pass_thru {
layout_result_0 = #xegpu.layout<sg_layout = [6]>,
layout_operand_5 = #xegpu.layout<sg_layout = [7]>
} : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
%res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
gpu.return %res2 : vector<8x16xf32>
}
// CHECK-LABEL: @load_static_layout_mixed(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
// Verify that linear-indices computation uses layout from the 'indices' producer op (%2).
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
// CHECK-SAME: {{{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [7]>
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [7]>,
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>} : vector<8x16xi1>, vector<8x16xf32>
}

// -----

gpu.module @xevm_module {
gpu.func @load_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
%off0: index, %off1: index, %off2: index,
%mask: vector<8x16xi1>) -> vector<8x16xf32> {
%pass_thru = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1]>} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
%0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout<sg_layout = [3]>} : vector<8x1xindex> to vector<8x16xindex>
%1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout<sg_layout = [4]>} : vector<1x16xindex> to vector<8x16xindex>
%2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>

%res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
%pass_thru {
layout_result_0 = #xegpu.layout<sg_layout = [6]>,
layout_operand_4 = #xegpu.layout<sg_layout = [99]>, // overriding %2's layout
layout_operand_5 = #xegpu.layout<sg_layout = [7]>
} : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
%res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
gpu.return %res2 : vector<8x16xf32>
}
// CHECK-LABEL: @load_dynamic_layout_mixed_override(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
// CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
// Verify that the linear-indices computation uses the layout from the 'indices' producer op (%2)
// and not its overridden version from the gather op (sg_layout = [99]).
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
// CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [99]>, layout_operand_2 = #xegpu.layout<sg_layout = [7]>
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>}
// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [7]>,
// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>} : vector<8x16xi1>, vector<8x16xf32>
}

// -----

gpu.module @xevm_module {
gpu.func @load_with_cache_hints(%source: memref<8x16x32xf32>,
%off1: index, %off2: index, %off3: index,
%indices: vector<8xindex>, %mask: vector<8xi1>,
%pass_thru: vector<8xf32>) -> vector<8xf32> {
%0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
%pass_thru {
l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<streaming>
} : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
gpu.return %0 : vector<8xf32>
}
// CHECK-LABEL: @load_with_cache_hints(
// CHECK: xegpu.load {{[^<]*}}
// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<streaming>}>
}

// -----

gpu.module @xevm_module {
gpu.func @load_with_partial_cache_hints(%source: memref<8x16x32xf32>,
%off1: index, %off2: index, %off3: index,
%indices: vector<8xindex>, %mask: vector<8xi1>,
%pass_thru: vector<8xf32>) -> vector<8xf32> {
%0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
%pass_thru {
l1_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<streaming>
} : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
gpu.return %0 : vector<8xf32>
}
// CHECK-LABEL: @load_with_partial_cache_hints(
// CHECK: xegpu.load {{[^<]*}}
// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
}