diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 5cb47b2accd68..c0c4394f73d4a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -194,7 +194,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
+                                           bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -207,6 +208,14 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  if (isScattered) {
+    packingFactor =
+        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
+            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
+            : 1;
+    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
+                      LaneData({1, packingFactor}));
+  }
   if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
     packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
   return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
@@ -214,7 +223,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
+                                           bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
          "Expected 1D or 2D TensorDesc.");
@@ -227,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
 
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-  if (tdescTy.isScattered()) {
+  if (isScattered) {
     int packingFactor =
         bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
             ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -541,21 +551,29 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
 }
 
-/// Propagate the layout of the result to the tensor descriptor and mask
+/// Propagate the layout of the result to the tensor descriptor, mask and offset
 /// operands in LoadGatherOp.
 void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  // The layout is strictly determined by the tensor descriptor type.
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
+  // The layout is strictly determined by the payload type.
+  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
+  if (!payloadTy) {
+    load.emitWarning("Not propagating, non-vector payload supplied.");
+    return;
+  }
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
 
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
 
   // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(layout));
-  // Propagate the new layout to the mask operand.
+  if (isa<xegpu::TensorDescType>(load.getSourceType()))
+    propagateIfChanged(operands[0], operands[0]->meet(layout));
+  // Propagate the new layout to the mask and optional offset operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+  if (load.getOffsets())
+    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 
 /// Propagate the layout of the descriptor to the vector offset operand in
@@ -572,31 +590,39 @@ void LayoutInfoPropagation::visitCreateDescOp(
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
-/// Set the layout for the value, tensor descriptor, and mask operands in the
-/// StoreScatterOp.
+/// Set the layout for the value, tensor descriptor, offset and mask operands
+/// in the StoreScatterOp.
 void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
   // Currently, for 2D StoreScatterOp we expect that the height dimension of
   // the tensor descriptor is equal to the subgroup size. This is ensured by
   // the op verifier.
-  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
-  if (tdescShape.size() > 1)
+  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
+  if (!payloadTy) {
+    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
+    return;
+  }
+  auto payloadShape = payloadTy.getShape();
+  if (payloadShape.size() > 1)
     assert(
-        tdescShape[0] == xegpu::targetinfo::subgroupSize &&
+        payloadShape[0] == xegpu::targetinfo::subgroupSize &&
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
 
-  LayoutInfo layout =
-      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
+  LayoutInfo payloadLayout =
+      getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
 
-  // Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(layout));
-  // Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(layout));
-  // Use default 1D layout for mask operand.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+  // Propagate the payload operand layout.
+  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
+  // Propagate the destination (if tdesc) operand layout.
+  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
+  // Propagate the new layout to the mask and optional offset operand.
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+  if (storeScatter.getOffsets())
+    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index dddb5eaece2cb..6b8367dd8c201 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,6 +807,200 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to single-element vectors;
+/// **it is assumed that their producers will also be distributed**.
+/// The payload vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets and mask vector");
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+    if (storeVecTy.getRank() > 2)
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Expected at most 2D result at SG level");
+
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          storeScatterOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
+    VectorType expectedPayloadTy = VectorType::get(
+        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = storeScatterOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        expectedPayloadTy, operands[1].getType(),
+        distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
+
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check whether the yield operand was produced by the *last* scattered
+      // load op, to avoid sinking it before barriers (maintain memory order).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadGatherOp");
+
+    auto loadGatherOp =
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()) ||
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Load op must have vector arguments for offsets and mask");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets and mask vector");
+    // Assume offset and mask producers will be distributed as well.
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+    xegpu::LayoutAttr layoutOffsets =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = loadGatherOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -819,10 +1013,11 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 0214d84f2c16f..cba3f0bd690c3 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -162,6 +162,40 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   return
 }
 
+// -----
+// CHECK-LABEL: func.func @scatter_ops_chunksize(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+      : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @scatter_ops(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops(%src: memref<256xf16>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+  xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  return
+}
+
 // -----
 // CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..a39aa90bbe3a8 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,3 +319,39 @@ gpu.module @test {
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
+    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+      layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+    } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops(%src: memref<256xf16>) {
+    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 {
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+    xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}