Commit aca4002

Revert "[MLIR][XeGPU] Scattered ops sg-to-wi distribution (#154949)"
This reverts commit 5777f71.
1 parent e20ce96 commit aca4002

4 files changed: 24 additions, 315 deletions

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 20 additions & 46 deletions

@@ -194,8 +194,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           bool isScattered = false) {
+static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -208,23 +207,14 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
-  if (isScattered) {
-    packingFactor =
-        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
-            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
-            : 1;
-    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
-                      LaneData({1, packingFactor}));
-  }
   if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
     packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
   return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
                     LaneData({1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
-                                           bool isScattered = false) {
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
   // Expecting a 1D or 2D vector.
   assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
          "Expected 1D or 2D TensorDesc.");
@@ -237,7 +227,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
 
-  if (isScattered) {
+  if (tdescTy.isScattered()) {
     int packingFactor =
         bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
             ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -551,29 +541,21 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
 }
 
-/// Propagate the layout of the result to the tensor descriptor, mask and offset
+/// Propagate the layout of the result to the tensor descriptor and mask
 /// operands in LoadGatherOp.
 void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  // The layout is strictly determined by the payload type.
-  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
-  if (!payloadTy) {
-    load.emitWarning("Not propagating, non-vector payload supplied.");
-    return;
-  }
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
+  // The layout is strictly determined by the tensor descriptor type.
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
 
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
 
   // Propagate the new layout to the tensor descriptor operand.
-  if (isa<xegpu::TensorDescType>(load.getSourceType()))
-    propagateIfChanged(operands[0], operands[0]->meet(layout));
-  // Propagate the new layout to the mask and optional offset operand.
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
+  // Propagate the new layout to the mask operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
-  if (load.getOffsets())
-    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 
 /// Propagate the layout of the descriptor to the vector offset operand in
@@ -590,39 +572,31 @@ void LayoutInfoPropagation::visitCreateDescOp(
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
-/// Set the layout for the value, tensor descriptor, offset and mask operands in
-/// the StoreScatterOp.
+/// Set the layout for the value, tensor descriptor, and mask operands in the
+/// StoreScatterOp.
 void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
   // Currently, for 2D StoreScatterOp we expect that the height dimension of
   // the tensor descriptor is equal to the subgroup size. This is ensured by
   // the op verifier.
-  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
-  if (!payloadTy) {
-    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
-    return;
-  }
-  auto payloadShape = payloadTy.getShape();
-  if (payloadShape.size() > 1)
+  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
+  if (tdescShape.size() > 1)
     assert(
-        payloadShape[0] == xegpu::targetinfo::subgroupSize &&
+        tdescShape[0] == xegpu::targetinfo::subgroupSize &&
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
 
-  LayoutInfo payloadLayout =
-      getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
+  LayoutInfo layout =
+      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
 
+  // Propagate the value layout.
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
+  // Propagate the tensor descriptor layout.
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
+  // Use default 1D layout for mask operand.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
-  // Propagate the payload operand layout
-  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
-  // Propagate the destination (if tdesc) operand layout
-  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
-    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
-  // Propagate the new layout to the mask and optional offset operand.
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
-  if (storeScatter.getOffsets())
-    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
 }
 
 namespace {
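Side note (not part of the commit): both the tensor-descriptor path kept by the revert and the removed payload path above derive the per-lane packing factor from the element bitwidth. The sketch below is a standalone illustration of that rule; the target constants and the helper name scatterPackingFactor are assumptions for illustration, not identifiers from the XeGPU sources (the assumed values do match the lane_layout = [16] / lane_data = [1, 2] that the removed f16 tests expected).

  // Standalone sketch of the gather/scatter packing-factor rule; the
  // constants below are assumed for illustration only.
  #include <cstdio>

  namespace targetinfo {
  constexpr unsigned subgroupSize = 16;                     // assumed
  constexpr unsigned packedSizeInBitsForDefault = 16;       // assumed
  constexpr unsigned packedSizeInBitsForGatherScatter = 32; // assumed
  } // namespace targetinfo

  // Narrow elements are packed so each lane moves
  // packedSizeInBitsForGatherScatter bits per scattered access.
  static unsigned scatterPackingFactor(unsigned bitwidth) {
    return bitwidth < targetinfo::packedSizeInBitsForGatherScatter
               ? targetinfo::packedSizeInBitsForGatherScatter / bitwidth
               : 1;
  }

  int main() {
    std::printf("f16: %u\n", scatterPackingFactor(16)); // 2 -> lane_data [1, 2]
    std::printf("f32: %u\n", scatterPackingFactor(32)); // 1 -> lane_data [1, 1]
  }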

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 4 additions & 199 deletions

@@ -807,200 +807,6 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-/// Distribute a scattered store op. The offsets argument is required.
-/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
-/// The layouts are fixed and implicit: one offset/mask per lane.
-/// The pass changes the offset/mask vector shapes to a
-/// single-element vector, **it is assumed that their producer will also be
-/// distributed**. The payload vector also has a fixed distribution:
-/// no chunk size -> vector of one element.
-/// chunk size -> vector of the innermost dimension of the SG-payload.
-/// Example 1 (no chunk size):
-///   %mask = producer_op : vector<16xi1>
-///   %offset = producer_op : vector<16xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
-///     memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To
-///   %mask = producer_op : vector<1xi1>
-///   %offset = producer_op : vector<1xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
-///     memref<256xf16>, vector<1xindex>, vector<1xi1>
-/// Example 2 (chunk size, same mask and offsets):
-///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To
-///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-struct StoreDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
-    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
-    if (!storeScatterOp)
-      return failure();
-    auto offsets = storeScatterOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()))
-      return rewriter.notifyMatchFailure(
-          storeScatterOp, "Store op must have a vector of offsets argument");
-    VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
-    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Expected 1D offsets and mask vector");
-    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
-    if (storeVecTy.getRank() > 2)
-      return rewriter.notifyMatchFailure(
-          storeScatterOp, "Expected at most 2D result at SG level");
-
-    std::string layoutPayloadName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
-    std::string layoutOffsetsName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
-    std::string layoutMaskName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
-
-    xegpu::LayoutAttr layoutPayload =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
-    xegpu::LayoutAttr layoutOffsets =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
-    xegpu::LayoutAttr layoutMask =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
-
-    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
-    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
-    FailureOr<VectorType> distMaskByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
-    if (failed(distStoreVecByWarpOpOrFailure) ||
-        failed(distOffsetsByWarpOpOrFailure) ||
-        failed(distMaskByWarpOpOrFailure)) {
-      return rewriter.notifyMatchFailure(
-          storeScatterOp,
-          "Some vector operands have no layouts, using defaults instead.");
-    }
-    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
-    VectorType expectedPayloadTy = VectorType::get(
-        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypesToYield = {
-        expectedPayloadTy, operands[1].getType(),
-        distOffsetsByWarpOpOrFailure.value(),
-        distMaskByWarpOpOrFailure.value()};
-
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
-    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
-        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
-    rewriter.setInsertionPointAfter(newWarpOp);
-    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
-        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
-        storeScatterOp->getAttrs());
-    xegpu::removeLayoutAttrs(newOp);
-    rewriter.eraseOp(storeScatterOp);
-    return success();
-  }
-};
-
-/// Distribute a scattered load op. The logic and requirements are the same as
-/// for the scattered store distribution. The warpOp's payload vector is
-/// expected to be distributed by the load's result consumer.
-/// Example 1 (no chunk size):
-///   %mask = producer_op : vector<16xi1>
-///   %offset = producer_op : vector<16xindex>
-///   %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
-///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
-/// To
-///   %mask = producer_op : vector<1xi1>
-///   %offset = producer_op : vector<1xindex>
-///   %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
-///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
-/// Example 2 (chunk size, same mask and offsets):
-///   %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-/// To
-///   %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-struct LoadDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
-      // Check if the yield operand that was produced by the *last* scattered
-      // load op to avoid sinking it before barriers (maintain memory order).
-      return isa<xegpu::LoadGatherOp>(op) &&
-             warpOp.getTerminator()->getPrevNode() == op;
-    });
-    if (!producedByLastLoad)
-      return rewriter.notifyMatchFailure(
-          warpOp, "The last op is not xegpu::LoadGatherOp");
-
-    auto loadGatherOp =
-        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
-    auto offsets = loadGatherOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()) ||
-        !isa<VectorType>(loadGatherOp.getMask().getType()))
-      return rewriter.notifyMatchFailure(
-          loadGatherOp,
-          "Load op must have a vector arguments for offsets and mask");
-    VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
-    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Expected 1D offsets and mask vector");
-    // Assume offset and mask producers will be distributed as well.
-    std::string layoutOffsetsName =
-        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
-    std::string layoutMaskName =
-        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
-
-    xegpu::LayoutAttr layoutOffsets =
-        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
-    xegpu::LayoutAttr layoutMask =
-        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
-
-    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
-    FailureOr<VectorType> distMaskByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
-    if (failed(distOffsetsByWarpOpOrFailure) ||
-        failed(distMaskByWarpOpOrFailure)) {
-      return rewriter.notifyMatchFailure(
-          loadGatherOp,
-          "Some vector operands have no layouts, using defaults instead.");
-    }
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypesToYield = {
-        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
-        distMaskByWarpOpOrFailure.value()};
-
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
-
-    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
-        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
-    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
-    VectorType loadVecTy =
-        cast<VectorType>(warpOp.getResult(operandIdx).getType());
-
-    rewriter.setInsertionPointAfter(newWarpOp);
-    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
-        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
-        loadGatherOp->getAttrs());
-    xegpu::removeLayoutAttrs(newOp);
-    Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
-    return success();
-  }
-};
-
 } // namespace
 
 namespace {
@@ -1013,11 +819,10 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
-           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
-           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
-          patterns.getContext());
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
+      patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
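Side note (not part of the commit): the removed LoadDistribution and StoreDistribution patterns compute per-lane operand types from the subgroup-level types and the lane layout, then flatten the payload with VectorType::get({distPayloadTy.getNumElements()}, ...). A minimal sketch of that shape arithmetic follows, assuming getDistVecTypeBasedOnLaneLayout divides each dimension by the matching lane-layout entry; distributeShape below is a hypothetical stand-in, not a function from the XeGPU sources.

  // Hypothetical illustration of per-lane shape computation for the
  // scattered-op distribution described in the removed doc comments.
  #include <cassert>
  #include <cstdio>
  #include <vector>

  static std::vector<long> distributeShape(const std::vector<long> &sgShape,
                                           const std::vector<long> &laneLayout) {
    assert(sgShape.size() == laneLayout.size() && "rank mismatch");
    std::vector<long> perLane;
    for (size_t i = 0; i < sgShape.size(); ++i) {
      assert(sgShape[i] % laneLayout[i] == 0 && "shape not divisible by layout");
      perLane.push_back(sgShape[i] / laneLayout[i]);
    }
    return perLane;
  }

  int main() {
    // Chunked store payload: vector<16x8xf16> with lane_layout [16, 1]
    // becomes vector<1x8xf16>, which the pattern flattens to vector<8xf16>.
    std::vector<long> payload = distributeShape({16, 8}, {16, 1});
    // Offsets and mask: vector<16xindex> / vector<16xi1> with lane_layout [16]
    // become single-element vectors, one offset/mask per lane.
    std::vector<long> offsets = distributeShape({16}, {16});
    std::printf("payload %ldx%ld, offsets %ld\n", payload[0], payload[1],
                offsets[0]);
  }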

mlir/test/Dialect/XeGPU/propagate-layout.mlir

Lines changed: 0 additions & 34 deletions

@@ -162,40 +162,6 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   return
 }
 
-// -----
-// CHECK-LABEL: func.func @scatter_ops_chunksize(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
-  %1 = arith.constant dense<1>: vector<16xi1>
-  %offset = arith.constant dense<12> : vector<16xindex>
-  %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
-    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-  xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
-    : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-  return
-}
-
-// -----
-// CHECK-LABEL: func.func @scatter_ops(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-func.func @scatter_ops(%src: memref<256xf16>) {
-  %1 = arith.constant dense<1>: vector<16xi1>
-  %offset = arith.constant dense<12> : vector<16xindex>
-  %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-  xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-  return
-}
-
 // -----
 // CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
