
Commit 3f0bb99

akroviakov and charithaintc authored and committed
Automerge: [MLIR][XeGPU] Reapply attempt for "Scattered ops sg-to-wi distribution #154949" (#156924)
This PR is a reapply of llvm/llvm-project#154949, which failed one of the sanitizer checks. The issue was querying the `warpOp` results in `LoadDistribution` after calling `moveRegionToNewWarpOpAndAppendReturns()`, which resulted in a use-after-free. This PR solves the issue by moving the op query before the call and is otherwise identical to the one linked above.

---------

Co-authored-by: Charitha Saumya <[email protected]>
2 parents c2c175a + 6c6afdd commit 3f0bb99
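
The reordering described in the commit message can be seen in the LoadDistribution pattern further down. As a minimal sketch using the names from that pattern (an excerpt for illustration, not a standalone program): the distributed result type is read from the original warp op before the call that replaces it.

    // Query the result of the *old* warpOp first: moveRegionToNewWarpOpAndAppendReturns()
    // builds a replacement warp op and erases the original one.
    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType loadVecTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    // Only now rebuild the warp op; `warpOp` must not be used after this call.
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);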

File tree

4 files changed: +315 −24 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 46 additions & 20 deletions
@@ -194,7 +194,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
+                                           bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -207,14 +208,23 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  if (isScattered) {
+    packingFactor =
+        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
+            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
+            : 1;
+    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
+                      LaneData({1, packingFactor}));
+  }
   if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
     packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
   return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
                     LaneData({1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
+                                           bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
          "Expected 1D or 2D TensorDesc.");
@@ -227,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
 
-  if (tdescTy.isScattered()) {
+  if (isScattered) {
     int packingFactor =
         bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -541,21 +551,29 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
 }
 
-/// Propagate the layout of the result to the tensor descriptor and mask
+/// Propagate the layout of the result to the tensor descriptor, mask and offset
 /// operands in LoadGatherOp.
 void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  // The layout is strictly determined by the tensor descriptor type.
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
+  // The layout is strictly determined by the payload type.
+  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
+  if (!payloadTy) {
+    load.emitWarning("Not propagating, non-vector payload supplied.");
+    return;
+  }
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
 
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
 
   // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(layout));
-  // Propagate the new layout to the mask operand.
+  if (isa<xegpu::TensorDescType>(load.getSourceType()))
+    propagateIfChanged(operands[0], operands[0]->meet(layout));
+  // Propagate the new layout to the mask and optional offset operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+  if (load.getOffsets())
+    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 
 /// Propagate the layout of the descriptor to the vector offset operand in
@@ -572,31 +590,39 @@ void LayoutInfoPropagation::visitCreateDescOp(
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
-/// Set the layout for the value, tensor descriptor, and mask operands in the
-/// StoreScatterOp.
+/// Set the layout for the value, tensor descriptor, offset and mask operands in
+/// the StoreScatterOp.
 void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
   // Currently, for 2D StoreScatterOp we expect that the height dimension of
   // the tensor descriptor is equal to the subgroup size. This is ensured by
   // the op verifier.
-  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
-  if (tdescShape.size() > 1)
+  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
+  if (!payloadTy) {
+    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
+    return;
+  }
+  auto payloadShape = payloadTy.getShape();
+  if (payloadShape.size() > 1)
    assert(
-        tdescShape[0] == xegpu::targetinfo::subgroupSize &&
+        payloadShape[0] == xegpu::targetinfo::subgroupSize &&
        "Expected the first dimension of 2D tensor descriptor to be equal to "
        "subgroup size.");
 
-  LayoutInfo layout =
-      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
+  LayoutInfo payloadLayout =
+      getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
 
-  // Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(layout));
-  // Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(layout));
-  // Use default 1D layout for mask operand.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+  // Propagate the payload operand layout
+  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
+  // Propagate the destination (if tdesc) operand layout
+  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
+  // Propagate the new layout to the mask and optional offset operand.
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+  if (storeScatter.getOffsets())
+    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
 }
 
 namespace {
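
For intuition about the scattered packing factor introduced above: assuming packedSizeInBitsForGatherScatter is 32 bits (consistent with the lane_data = [1, 2] the new test expects for an f16 payload with lane_layout [16, 1]), a small standalone sketch of the computation is:

    // Standalone sketch (plain C++, not the MLIR helper): scattered-layout
    // packing factor for a few element bitwidths, assuming a packed size of
    // 32 bits for gather/scatter.
    #include <cstdio>
    int main() {
      const unsigned packedSizeInBitsForGatherScatter = 32; // assumed value
      const unsigned bitwidths[] = {8, 16, 32};
      for (unsigned bitwidth : bitwidths) {
        unsigned packingFactor =
            bitwidth < packedSizeInBitsForGatherScatter
                ? packedSizeInBitsForGatherScatter / bitwidth
                : 1;
        // e.g. f16 (16 bits) -> lane_data = [1, 2]
        std::printf("bitwidth=%u -> lane_data = [1, %u]\n", bitwidth,
                    packingFactor);
      }
      return 0;
    }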

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 199 additions & 4 deletions
@@ -807,6 +807,200 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to a
+/// single-element vector, **it is assumed that their producer will also be
+/// distributed**. The payload vector also has a fixed distribution:
+/// no chunk size -> vector of one element.
+/// chunk size -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets and mask vector");
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+    if (storeVecTy.getRank() > 2)
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Expected at most 2D result at SG level");
+
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          storeScatterOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
+    VectorType expectedPayloadTy = VectorType::get(
+        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = storeScatterOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        expectedPayloadTy, operands[1].getType(),
+        distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
+
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///   %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///   %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check if the yield operand that was produced by the *last* scattered
+      // load op to avoid sinking it before barriers (maintain memory order).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadGatherOp");
+
+    auto loadGatherOp =
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()) ||
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Load op must have a vector arguments for offsets and mask");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets and mask vector");
+    // Assume offset and mask producers will be distributed as well.
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+    xegpu::LayoutAttr layoutOffsets =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = loadGatherOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -819,10 +1013,11 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
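
To make the payload shapes in StoreDistribution concrete: for the chunked case in the tests, the subgroup-level payload is vector<16x8xf16> with lane_layout [16, 1], so each lane owns a 1x8 slice that the pattern yields as a flat vector<8xf16> (via distPayloadTy.getNumElements()). A standalone sketch of that arithmetic, under those assumed shapes:

    // Standalone sketch (plain C++, not the MLIR pattern): per-lane payload
    // size for the chunked store example, assuming a 16x8 subgroup payload
    // and lane_layout [16, 1].
    #include <cstdio>
    int main() {
      const long sgShape[2] = {16, 8};     // subgroup-level payload shape
      const long laneLayout[2] = {16, 1};  // lanes assigned per dimension
      long flatPerLane = 1;
      for (int d = 0; d < 2; ++d)
        flatPerLane *= sgShape[d] / laneLayout[d]; // per-lane extent per dim
      std::printf("per-lane payload: vector<%ldxf16>\n", flatPerLane); // 8
      return 0;
    }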

mlir/test/Dialect/XeGPU/propagate-layout.mlir

Lines changed: 34 additions & 0 deletions
@@ -162,6 +162,40 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   return
 }
 
+// -----
+// CHECK-LABEL: func.func @scatter_ops_chunksize(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+      : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @scatter_ops(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops(%src: memref<256xf16>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+  xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  return
+}
+
 // -----
 // CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
