From e174b696ecdaf379df6807c600d5dceaf797c74f Mon Sep 17 00:00:00 2001 From: Artem Kroviakov Date: Sat, 23 Aug 2025 10:22:34 +0000 Subject: [PATCH 1/8] [MLIR][XeGPU] Scattered ops sg-to-wi distribution --- .../Transforms/XeGPUSubgroupDistribute.cpp | 141 +++++++++++++++++- .../Dialect/XeGPU/subgroup-distribute.mlir | 42 ++++++ 2 files changed, 179 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index dddb5eaece2cb..3c3a52581ce90 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -807,6 +807,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { } }; +struct StoreDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + auto yield = cast( + warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + Operation *lastNode = yield->getPrevNode(); + auto storeScatterOp = dyn_cast_or_null(lastNode); + if (!storeScatterOp) + return failure(); + else if (!storeScatterOp.getOffsets()) + return rewriter.notifyMatchFailure(storeScatterOp, + "Store op must have offsets argument"); + else if (cast(storeScatterOp.getOffsets().getType()) + .getRank() != 1) + return rewriter.notifyMatchFailure(storeScatterOp, + "Expected 1D offsets vector"); + + VectorType storeVecTy = + cast(storeScatterOp.getValue().getType()); + assert(storeVecTy.getRank() <= 2 && + "Expected at most 2D result at SG level"); + VectorType distStoreVecTy; + if (storeVecTy.getRank() == 2) + distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0); + else // rank 1 + distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1); + + SmallVector newRetIndices; + SmallVector operands = + llvm::to_vector_of(storeScatterOp->getOperands()); + SmallVector operandTypes = + llvm::to_vector_of(storeScatterOp->getOperandTypes()); + operandTypes[0] = distStoreVecTy; + + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, operands, operandTypes, newRetIndices); + SmallVector newStoreScatterOpOperands = llvm::map_to_vector( + newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + + Value offsetsVec = newStoreScatterOpOperands[2]; + Value maskVec = newStoreScatterOpOperands[3]; + + auto loc = newWarpOp.getLoc(); + Value laneId = warpOp.getLaneid(); + rewriter.setInsertionPointAfter(newWarpOp); + Value laneOffset = + vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId); + laneOffset = vector::BroadcastOp::create( + rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset); + Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId); + laneMask = vector::BroadcastOp::create( + rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask); + newStoreScatterOpOperands[2] = laneOffset; + newStoreScatterOpOperands[3] = laneMask; + + xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create( + rewriter, loc, TypeRange{}, newStoreScatterOpOperands, + storeScatterOp->getAttrs()); + xegpu::removeLayoutAttrs(newOp); + rewriter.eraseOp(storeScatterOp); + return success(); + } +}; + +struct LoadDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult 
matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) { + if (!isa(op)) + return false; + auto yield = cast( + warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + return yield->getPrevNode() == op; + }); + if (!yieldOperand) + return rewriter.notifyMatchFailure( + warpOp, "warp result is not a xegpu::LoadGatherOp op"); + + auto loadGatherOp = + yieldOperand->get().getDefiningOp(); + if (!loadGatherOp.getOffsets()) + return rewriter.notifyMatchFailure(loadGatherOp, + "Load op must have offsets argument"); + else if (cast(loadGatherOp.getOffsets().getType()).getRank() != + 1) + return rewriter.notifyMatchFailure(loadGatherOp, + "Expected 1D offsets vector"); + + SmallVector newRetIndices; + SmallVector operands = + llvm::to_vector_of(loadGatherOp->getOperands()); + SmallVector operandTypes = + llvm::to_vector_of(loadGatherOp->getOperandTypes()); + + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, operands, operandTypes, newRetIndices); + + SmallVector newLoadGatherOperands = llvm::map_to_vector( + newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + + const unsigned operandIdx = yieldOperand->getOperandNumber(); + VectorType loadVecTy = + cast(warpOp.getResult(operandIdx).getType()); + assert(loadVecTy.getRank() == 1 && "Expected a distributed vector"); + + Value offsetsVec = newLoadGatherOperands[1]; + Value maskVec = newLoadGatherOperands[2]; + auto loc = newWarpOp.getLoc(); + Value laneId = warpOp.getLaneid(); + rewriter.setInsertionPointAfter(newWarpOp); + Value laneOffset = + vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId); + laneOffset = vector::BroadcastOp::create( + rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset); + Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId); + laneMask = vector::BroadcastOp::create( + rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask); + newLoadGatherOperands[1] = laneOffset; + newLoadGatherOperands[2] = laneMask; + + xegpu::LoadGatherOp newOp = rewriter.create( + loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs()); + Value distributedVal = newWarpOp.getResult(operandIdx); + rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0)); + return success(); + } +}; + } // namespace namespace { @@ -819,10 +949,11 @@ struct XeGPUSubgroupDistributePass final void xegpu::populateXeGPUSubgroupDistributePatterns( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns + .add( + patterns.getContext()); } void XeGPUSubgroupDistributePass::runOnOperation() { @@ -837,6 +968,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (!isa(operand.get().getType())) continue; + if (isa(op)) + continue; auto layout = xegpu::getDistributeLayoutAttrOfType(operand); if (!layout) { diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 54ef56e013abb..b319162dc3f25 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -319,3 +319,45 @@ gpu.module @test { gpu.return } } + +// ----- +// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) { +// CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> +// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id +// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from 
vector<16xindex> +// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex> +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint, +// CHECK-SAME: l2_hint = #xegpu.cache_hint}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16> +// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint, +// CHECK-SAME: l2_hint = #xegpu.cache_hint}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> +gpu.module @test { + gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) { + %1 = arith.constant dense<1>: vector<16xi1> + %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> + xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) { +// CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> +// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id +// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex> +// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex> +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint, +// CHECK-SAME: l2_hint = #xegpu.cache_hint}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16> +// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint, +// CHECK-SAME: l2_hint = #xegpu.cache_hint}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> +gpu.module @test { + gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) { + %1 = arith.constant dense<1>: vector<16xi1> + %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> + xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + gpu.return + } +} From 6d2296888df03dfdf08c82b3a39bb890ed465d4c Mon Sep 17 00:00:00 2001 From: Artem Kroviakov Date: Mon, 25 Aug 2025 18:04:16 +0000 Subject: [PATCH 2/8] Assume distributable offset and mask producers --- .../Transforms/XeGPUSubgroupDistribute.cpp | 54 ++++++------------- .../Dialect/XeGPU/subgroup-distribute.mlir | 30 +++++------ 2 files changed, 31 insertions(+), 53 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 3c3a52581ce90..cf2c933e16cb4 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -817,14 +817,14 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { auto storeScatterOp = dyn_cast_or_null(lastNode); if (!storeScatterOp) return failure(); - else if (!storeScatterOp.getOffsets()) + if (!storeScatterOp.getOffsets()) return rewriter.notifyMatchFailure(storeScatterOp, "Store op must have offsets argument"); - else if 
(cast(storeScatterOp.getOffsets().getType()) - .getRank() != 1) + VectorType offsetsTy = + cast(storeScatterOp.getOffsets().getType()); + if (offsetsTy.getRank() != 1) return rewriter.notifyMatchFailure(storeScatterOp, "Expected 1D offsets vector"); - VectorType storeVecTy = cast(storeScatterOp.getValue().getType()); assert(storeVecTy.getRank() <= 2 && @@ -836,33 +836,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1); SmallVector newRetIndices; - SmallVector operands = - llvm::to_vector_of(storeScatterOp->getOperands()); + SmallVector operands = storeScatterOp->getOperands(); SmallVector operandTypes = llvm::to_vector_of(storeScatterOp->getOperandTypes()); operandTypes[0] = distStoreVecTy; + // Assume offset and mask pproducers will be distributed as well. + operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); + operandTypes[3] = VectorType::get( + {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType())); gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, operands, operandTypes, newRetIndices); SmallVector newStoreScatterOpOperands = llvm::map_to_vector( newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); - Value offsetsVec = newStoreScatterOpOperands[2]; - Value maskVec = newStoreScatterOpOperands[3]; - auto loc = newWarpOp.getLoc(); - Value laneId = warpOp.getLaneid(); rewriter.setInsertionPointAfter(newWarpOp); - Value laneOffset = - vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId); - laneOffset = vector::BroadcastOp::create( - rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset); - Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId); - laneMask = vector::BroadcastOp::create( - rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask); - newStoreScatterOpOperands[2] = laneOffset; - newStoreScatterOpOperands[3] = laneMask; - xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create( rewriter, loc, TypeRange{}, newStoreScatterOpOperands, storeScatterOp->getAttrs()); @@ -892,16 +881,20 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern { if (!loadGatherOp.getOffsets()) return rewriter.notifyMatchFailure(loadGatherOp, "Load op must have offsets argument"); - else if (cast(loadGatherOp.getOffsets().getType()).getRank() != - 1) + VectorType offsetsTy = + cast(loadGatherOp.getOffsets().getType()); + if (offsetsTy.getRank() != 1) return rewriter.notifyMatchFailure(loadGatherOp, "Expected 1D offsets vector"); SmallVector newRetIndices; - SmallVector operands = - llvm::to_vector_of(loadGatherOp->getOperands()); + SmallVector operands = loadGatherOp->getOperands(); SmallVector operandTypes = llvm::to_vector_of(loadGatherOp->getOperandTypes()); + // Assume offset and mask pproducers will be distributed as well. 
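+    // For example (a sketch of the intended result, not enforced here): with
+    // a 16-lane subgroup, a vector<16xindex> offsets operand and a
+    // vector<16xi1> mask operand are yielded out of the warp op as
+    // vector<1xindex> and vector<1xi1>, i.e. one element per lane.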
+ operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); + operandTypes[2] = VectorType::get( + {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType())); gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, operands, operandTypes, newRetIndices); @@ -914,21 +907,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern { cast(warpOp.getResult(operandIdx).getType()); assert(loadVecTy.getRank() == 1 && "Expected a distributed vector"); - Value offsetsVec = newLoadGatherOperands[1]; - Value maskVec = newLoadGatherOperands[2]; auto loc = newWarpOp.getLoc(); - Value laneId = warpOp.getLaneid(); rewriter.setInsertionPointAfter(newWarpOp); - Value laneOffset = - vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId); - laneOffset = vector::BroadcastOp::create( - rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset); - Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId); - laneMask = vector::BroadcastOp::create( - rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask); - newLoadGatherOperands[1] = laneOffset; - newLoadGatherOperands[2] = laneMask; - xegpu::LoadGatherOp newOp = rewriter.create( loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs()); Value distributedVal = newWarpOp.getResult(operandIdx); diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index b319162dc3f25..1c4684681b62b 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -323,19 +323,18 @@ gpu.module @test { // ----- // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) { // CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> -// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id -// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex> -// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex> -// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint, +// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint, // CHECK-SAME: l2_hint = #xegpu.cache_hint}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16> -// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint, +// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint, // CHECK-SAME: l2_hint = #xegpu.cache_hint}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> gpu.module @test { - gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) { + gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> - %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + %offset = arith.constant dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> - xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + 
xegpu.store %3, %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> gpu.return } @@ -344,19 +343,18 @@ gpu.module @test { // ----- // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) { // CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> -// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id -// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex> -// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex> -// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint, +// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint, // CHECK-SAME: l2_hint = #xegpu.cache_hint}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16> -// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint, +// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint, // CHECK-SAME: l2_hint = #xegpu.cache_hint}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> gpu.module @test { - gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) { + gpu.func @scatter_ops(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> - %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + %offset = arith.constant dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> - xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + xegpu.store %3, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> gpu.return } From a4d4e66062f797ff5fa1156f8450e55d3f124c46 Mon Sep 17 00:00:00 2001 From: Artem Kroviakov Date: Tue, 26 Aug 2025 14:23:30 +0000 Subject: [PATCH 3/8] Address feedback --- .../Transforms/XeGPUSubgroupDistribute.cpp | 123 ++++++++++++------ 1 file changed, 82 insertions(+), 41 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index cf2c933e16cb4..84278964cbb63 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -807,26 +807,47 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { } }; +/// Distribute a scattered store op. The offsets argument is required. +/// Both offset and mask vectors must be 1D and have #subgroup_size elements. +/// The layouts are fixed and implicit: one offset/mask per lane. +/// The pass changes the offset/mask vector shapes to a +/// single-element vector, **it is assumed that their producer will also be +/// distributed**. The payload vector also has a fixed distribution: +/// no chunk size -> vector of one element. +/// chunk size -> vector of the innermost dimension of the SG-payload. 
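+/// (i.e. with a chunk size, each lane keeps one row of the SG-level payload,
+/// its own chunk).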
+/// Example 1 (no chunk size): +/// %mask = producer_op : vector<16xi1> +/// %offset = producer_op : vector<16xindex> +/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>, +/// memref<256xf16>, vector<16xindex>, vector<16xi1> +/// To +/// %mask = producer_op : vector<1xi1> +/// %offset = producer_op : vector<1xindex> +/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>, +/// memref<256xf16>, vector<1xindex>, vector<1xi1> +/// Example 2 (chunk size, same mask and offsets): +/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> : +/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +/// To +/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> : +/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> struct StoreDistribution final : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - Operation *lastNode = yield->getPrevNode(); + Operation *lastNode = warpOp.getTerminator()->getPrevNode(); auto storeScatterOp = dyn_cast_or_null(lastNode); if (!storeScatterOp) return failure(); - if (!storeScatterOp.getOffsets()) - return rewriter.notifyMatchFailure(storeScatterOp, - "Store op must have offsets argument"); - VectorType offsetsTy = - cast(storeScatterOp.getOffsets().getType()); + auto offsets = storeScatterOp.getOffsets(); + if (!offsets || !isa(offsets.getType())) + return rewriter.notifyMatchFailure( + storeScatterOp, "Store op must have a vector of offsets argument"); + VectorType offsetsTy = cast(offsets.getType()); if (offsetsTy.getRank() != 1) return rewriter.notifyMatchFailure(storeScatterOp, "Expected 1D offsets vector"); - VectorType storeVecTy = - cast(storeScatterOp.getValue().getType()); + VectorType storeVecTy = cast(storeScatterOp.getValueType()); assert(storeVecTy.getRank() <= 2 && "Expected at most 2D result at SG level"); VectorType distStoreVecTy; @@ -837,23 +858,23 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { SmallVector newRetIndices; SmallVector operands = storeScatterOp->getOperands(); - SmallVector operandTypes = + SmallVector operandTypesToYield = llvm::to_vector_of(storeScatterOp->getOperandTypes()); - operandTypes[0] = distStoreVecTy; - // Assume offset and mask pproducers will be distributed as well. - operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); - operandTypes[3] = VectorType::get( + operandTypesToYield[0] = distStoreVecTy; + // Assume offset and mask producers will be distributed as well. 
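+    // E.g. for Example 2 above (chunk_size = 8), the types yielded from the
+    // warp op become {vector<8xf16>, memref<256xf16>, vector<1xindex>,
+    // vector<1xi1>}; operand 1, the destination, is passed through unchanged.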
+ operandTypesToYield[2] = + VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); + operandTypesToYield[3] = VectorType::get( {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType())); gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, operands, operandTypes, newRetIndices); + rewriter, warpOp, operands, operandTypesToYield, newRetIndices); SmallVector newStoreScatterOpOperands = llvm::map_to_vector( newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); - auto loc = newWarpOp.getLoc(); rewriter.setInsertionPointAfter(newWarpOp); xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create( - rewriter, loc, TypeRange{}, newStoreScatterOpOperands, + rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands, storeScatterOp->getAttrs()); xegpu::removeLayoutAttrs(newOp); rewriter.eraseOp(storeScatterOp); @@ -861,56 +882,75 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { } }; +/// Distribute a scattered load op. The logic and requirements are the same as +/// for the scattered store distribution. The warpOp's payload vector is +/// expected to be distributed by the load's result consumer. +/// Example 1 (no chunk size): +/// %mask = producer_op : vector<16xi1> +/// %offset = producer_op : vector<16xindex> +/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>, +/// vector<16xindex>, vector<16xi1> -> vector<16xf16> +/// To +/// %mask = producer_op : vector<1xi1> +/// %offset = producer_op : vector<1xindex> +/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>, +/// vector<1xindex>, vector<1xi1> -> vector<1xf16> +/// Example 2 (chunk size, same mask and offsets): +/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> : +/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> +/// To +/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> : +/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16> struct LoadDistribution final : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) { - if (!isa(op)) - return false; - auto yield = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - return yield->getPrevNode() == op; + OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) { + // Check if the yield operand that was produced by the *last* scattered + // load op to avoid sinking it before barriers (maintain memory order). 
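+      // E.g. if a gpu.barrier sits between the load and the yield, the
+      // pattern does not apply, so the load is never sunk past the barrier.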
+ return isa(op) && + warpOp.getTerminator()->getPrevNode() == op; }); - if (!yieldOperand) + if (!producedByLastLoad) return rewriter.notifyMatchFailure( - warpOp, "warp result is not a xegpu::LoadGatherOp op"); + warpOp, "The last op is not xegpu::LoadGatherOp"); auto loadGatherOp = - yieldOperand->get().getDefiningOp(); - if (!loadGatherOp.getOffsets()) - return rewriter.notifyMatchFailure(loadGatherOp, - "Load op must have offsets argument"); - VectorType offsetsTy = - cast(loadGatherOp.getOffsets().getType()); + producedByLastLoad->get().getDefiningOp(); + auto offsets = loadGatherOp.getOffsets(); + if (!offsets || !isa(offsets.getType())) + return rewriter.notifyMatchFailure( + loadGatherOp, "Load op must have a vector of offsets argument"); + VectorType offsetsTy = cast(offsets.getType()); if (offsetsTy.getRank() != 1) return rewriter.notifyMatchFailure(loadGatherOp, "Expected 1D offsets vector"); SmallVector newRetIndices; SmallVector operands = loadGatherOp->getOperands(); - SmallVector operandTypes = + SmallVector operandTypesToYield = llvm::to_vector_of(loadGatherOp->getOperandTypes()); - // Assume offset and mask pproducers will be distributed as well. - operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); - operandTypes[2] = VectorType::get( - {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType())); + // Assume offset and mask producers will be distributed as well. + operandTypesToYield[1] = + VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); + operandTypesToYield[2] = + VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType())); gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, operands, operandTypes, newRetIndices); + rewriter, warpOp, operands, operandTypesToYield, newRetIndices); SmallVector newLoadGatherOperands = llvm::map_to_vector( newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); - const unsigned operandIdx = yieldOperand->getOperandNumber(); + const unsigned operandIdx = producedByLastLoad->getOperandNumber(); VectorType loadVecTy = cast(warpOp.getResult(operandIdx).getType()); assert(loadVecTy.getRank() == 1 && "Expected a distributed vector"); - auto loc = newWarpOp.getLoc(); rewriter.setInsertionPointAfter(newWarpOp); xegpu::LoadGatherOp newOp = rewriter.create( - loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs()); + newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands, + loadGatherOp->getAttrs()); Value distributedVal = newWarpOp.getResult(operandIdx); rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0)); return success(); @@ -948,6 +988,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (!isa(operand.get().getType())) continue; + // Vectors operands of these ops have a fixed and implicit layout. 
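+      // (one offset/mask element per lane and a correspondingly distributed
+      // payload), so the layout attribute lookup below is skipped for them.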
if (isa(op)) continue; auto layout = From daa143f5572534839993b880dc441364cbf511fb Mon Sep 17 00:00:00 2001 From: Artem Kroviakov Date: Thu, 28 Aug 2025 14:25:17 +0000 Subject: [PATCH 4/8] Add layout-based distribution --- .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 61 ++++++++---- .../Transforms/XeGPUSubgroupDistribute.cpp | 98 ++++++++++++++----- mlir/test/Dialect/XeGPU/propagate-layout.mlir | 34 +++++++ .../Dialect/XeGPU/subgroup-distribute.mlir | 47 ++++++--- 4 files changed, 182 insertions(+), 58 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 5cb47b2accd68..46c5777d1c157 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -194,7 +194,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) { } /// Helper to get the default layout for a vector type. -static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) { +static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, + bool scattered = false) { // Expecting a 1D or 2D vector. assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) && "Expected 1D or 2D vector."); @@ -207,6 +208,14 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) { // Packing factor is determined by the element type bitwidth. int packingFactor = 1; unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); + if (scattered) { + packingFactor = + bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter + ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth + : 1; + return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}), + LaneData({1, packingFactor})); + } if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth; return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}), @@ -214,7 +223,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) { } /// Helper to get the default layout for a vector type. -static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) { +static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, + bool scattered = false) { // Expecting a 1D or 2D vector. assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) && "Expected 1D or 2D TensorDesc."); @@ -227,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) { // Packing factor is determined by the element type bitwidth. unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); - if (tdescTy.isScattered()) { + if (scattered) { int packingFactor = bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth @@ -541,21 +551,27 @@ void LayoutInfoPropagation::visitVectorBitcastOp( propagateIfChanged(operands[0], operands[0]->meet(resultLayout)); } -/// Propagate the layout of the result to the tensor descriptor and mask +/// Propagate the layout of the result to the tensor descriptor, mask and offset /// operands in LoadGatherOp. void LayoutInfoPropagation::visitLoadGatherOp( xegpu::LoadGatherOp load, ArrayRef operands, ArrayRef results) { - // The layout is strictly determined by the tensor descriptor type. - LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType()); + // The layout is strictly determined by the payload type. 
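+  // E.g. for a chunked vector<16x8xf16> payload this yields
+  // lane_layout = [16, 1] and lane_data = [1, 2], assuming a 16-lane
+  // subgroup and 32-bit gather/scatter packing.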
+ auto payloadTy = dyn_cast(load.getValueType()); + assert(payloadTy && "Only vector payload distribution is supported"); + LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true); // Mask operand should have 1D default layout. LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1); // Propagate the new layout to the tensor descriptor operand. - propagateIfChanged(operands[0], operands[0]->meet(layout)); - // Propagate the new layout to the mask operand. + if (isa(load.getSourceType())) + propagateIfChanged(operands[0], operands[0]->meet(layout)); + // Propagate the new layout to the mask and optional offset operand. propagateIfChanged(operands[1], operands[1]->meet(maskLayout)); + if (load.getOffsets()) { + propagateIfChanged(operands[2], operands[2]->meet(maskLayout)); + } } /// Propagate the layout of the descriptor to the vector offset operand in @@ -572,31 +588,36 @@ void LayoutInfoPropagation::visitCreateDescOp( propagateIfChanged(operands[1], operands[1]->meet(layout)); } -/// Set the layout for the value, tensor descriptor, and mask operands in the -/// StoreScatterOp. +/// Set the layout for the value, tensor descriptor, offset and mask operands in +/// the StoreScatterOp. void LayoutInfoPropagation::visitStoreScatterOp( xegpu::StoreScatterOp storeScatter, ArrayRef operands, ArrayRef results) { // Currently, for 2D StoreScatterOp we expect that the height dimension of // the tensor descriptor is equal to the subgroup size. This is ensured by // the op verifier. - ArrayRef tdescShape = storeScatter.getTensorDescType().getShape(); - if (tdescShape.size() > 1) + auto payloadTy = dyn_cast(storeScatter.getValueType()); + assert(payloadTy && "Only vector payload distribution is supported"); + auto payloadShape = payloadTy.getShape(); + if (payloadShape.size() > 1) assert( - tdescShape[0] == xegpu::targetinfo::subgroupSize && + payloadShape[0] == xegpu::targetinfo::subgroupSize && "Expected the first dimension of 2D tensor descriptor to be equal to " "subgroup size."); - LayoutInfo layout = - getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType()); + LayoutInfo payloadLayout = + getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true); - // Propagate the value layout. - propagateIfChanged(operands[0], operands[0]->meet(layout)); - // Propagate the tensor descriptor layout. - propagateIfChanged(operands[1], operands[1]->meet(layout)); - // Use default 1D layout for mask operand. LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1); + // Propagate the payload operand layout + propagateIfChanged(operands[0], operands[0]->meet(payloadLayout)); + // Propagate the destination (if tdesc) operand layout + if (isa(storeScatter.getDestType())) + propagateIfChanged(operands[1], operands[1]->meet(payloadLayout)); + // Propagate the new layout to the mask and optional offset operand. 
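+  // Both reuse the default 1-D layout, e.g. lane_layout = [16] and
+  // lane_data = [1] for a 16-lane subgroup (one element per lane).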
propagateIfChanged(operands[2], operands[2]->meet(maskLayout)); + if (storeScatter.getOffsets()) + propagateIfChanged(operands[3], operands[3]->meet(maskLayout)); } namespace { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 84278964cbb63..9bb0a2160f82e 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -844,9 +844,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { return rewriter.notifyMatchFailure( storeScatterOp, "Store op must have a vector of offsets argument"); VectorType offsetsTy = cast(offsets.getType()); - if (offsetsTy.getRank() != 1) + VectorType maskTy = cast(storeScatterOp.getMask().getType()); + if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1) return rewriter.notifyMatchFailure(storeScatterOp, - "Expected 1D offsets vector"); + "Expected 1D offsets and mask vector"); VectorType storeVecTy = cast(storeScatterOp.getValueType()); assert(storeVecTy.getRank() <= 2 && "Expected at most 2D result at SG level"); @@ -855,17 +856,45 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0); else // rank 1 distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1); - - SmallVector newRetIndices; - SmallVector operands = storeScatterOp->getOperands(); - SmallVector operandTypesToYield = - llvm::to_vector_of(storeScatterOp->getOperandTypes()); - operandTypesToYield[0] = distStoreVecTy; // Assume offset and mask producers will be distributed as well. - operandTypesToYield[2] = + VectorType distOffsetsTy = VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); - operandTypesToYield[3] = VectorType::get( + VectorType distMaskTy = VectorType::get( {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType())); + std::string layoutPayloadName = + xegpu::getLayoutName(storeScatterOp->getOpOperand(0)); + std::string layoutOffsetsName = + xegpu::getLayoutName(storeScatterOp->getOpOperand(2)); + std::string layoutMaskName = + xegpu::getLayoutName(storeScatterOp->getOpOperand(3)); + + xegpu::LayoutAttr layoutPayload = + storeScatterOp->getAttrOfType(layoutPayloadName); + xegpu::LayoutAttr layoutOffsets = + storeScatterOp->getAttrOfType(layoutOffsetsName); + xegpu::LayoutAttr layoutMask = + storeScatterOp->getAttrOfType(layoutMaskName); + + FailureOr distStoreVecByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy); + FailureOr distOffsetsByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy); + FailureOr distMaskByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy); + if (failed(distStoreVecByWarpOpOrFailure) || + failed(distOffsetsByWarpOpOrFailure) || + failed(distMaskByWarpOpOrFailure)) { + storeScatterOp.emitWarning( + "Some vector operands have no layouts, using defaults instead."); + } + distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy); + distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy); + distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy); + + SmallVector newRetIndices; + SmallVector operands = storeScatterOp->getOperands(); + SmallVector operandTypesToYield = { + distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy}; gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, operands, operandTypesToYield, newRetIndices); @@ 
-918,23 +947,47 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern { auto loadGatherOp = producedByLastLoad->get().getDefiningOp(); auto offsets = loadGatherOp.getOffsets(); - if (!offsets || !isa(offsets.getType())) + if (!offsets || !isa(offsets.getType()) || + !isa(loadGatherOp.getMask().getType())) return rewriter.notifyMatchFailure( - loadGatherOp, "Load op must have a vector of offsets argument"); + loadGatherOp, + "Load op must have a vector arguments for offsets and mask"); VectorType offsetsTy = cast(offsets.getType()); - if (offsetsTy.getRank() != 1) + VectorType maskTy = cast(loadGatherOp.getMask().getType()); + if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1) return rewriter.notifyMatchFailure(loadGatherOp, - "Expected 1D offsets vector"); + "Expected 1D offsets and mask vector"); + // Assume offset and mask producers will be distributed as well. + VectorType distOffsetsTy = + VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); + VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy)); + + std::string layoutOffsetsName = + xegpu::getLayoutName(loadGatherOp->getOpOperand(1)); + std::string layoutMaskName = + xegpu::getLayoutName(loadGatherOp->getOpOperand(2)); + + xegpu::LayoutAttr layoutOffsets = + loadGatherOp->getAttrOfType(layoutOffsetsName); + xegpu::LayoutAttr layoutMask = + loadGatherOp->getAttrOfType(layoutMaskName); + + FailureOr distOffsetsByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy); + FailureOr distMaskByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy); + if (failed(distOffsetsByWarpOpOrFailure) || + failed(distMaskByWarpOpOrFailure)) { + loadGatherOp.emitWarning( + "Some vector operands have no layouts, using defaults instead."); + } + distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy); + distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy); SmallVector newRetIndices; SmallVector operands = loadGatherOp->getOperands(); - SmallVector operandTypesToYield = - llvm::to_vector_of(loadGatherOp->getOperandTypes()); - // Assume offset and mask producers will be distributed as well. - operandTypesToYield[1] = - VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); - operandTypesToYield[2] = - VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType())); + SmallVector operandTypesToYield = {operands[0].getType(), + distOffsetsTy, distMaskTy}; gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, operands, operandTypesToYield, newRetIndices); @@ -951,6 +1004,7 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern { xegpu::LoadGatherOp newOp = rewriter.create( newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs()); + xegpu::removeLayoutAttrs(newOp); Value distributedVal = newWarpOp.getResult(operandIdx); rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0)); return success(); @@ -990,7 +1044,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // Vectors operands of these ops have a fixed and implicit layout. 
if (isa(op)) - continue; + continue; auto layout = xegpu::getDistributeLayoutAttrOfType(operand); if (!layout) { diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir index 0214d84f2c16f..cba3f0bd690c3 100644 --- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir +++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir @@ -162,6 +162,40 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) { return } +// ----- +// CHECK-LABEL: func.func @scatter_ops_chunksize( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { +// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<16xi1> +// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<12> : vector<16xindex> +// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> +// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +func.func @scatter_ops_chunksize(%src: memref<256xf16>) { + %1 = arith.constant dense<1>: vector<16xi1> + %offset = arith.constant dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> + : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> + xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> + : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + return +} + +// ----- +// CHECK-LABEL: func.func @scatter_ops( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { +// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<16xi1> +// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<12> : vector<16xindex> +// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> +// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +func.func @scatter_ops(%src: memref<256xf16>) { + %1 = arith.constant dense<1>: vector<16xi1> + %offset = arith.constant dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> + xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + return +} + // ----- // CHECK-LABEL: func.func @vector_bitcast_i16_to_f16( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) { diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 1c4684681b62b..ddb279eb070ff 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -324,18 +324,14 @@ gpu.module @test { // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) { // CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> // CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> -// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint, -// CHECK-SAME: l2_hint = #xegpu.cache_hint}> : memref<256xf16>, 
vector<1xindex>, vector<1xi1> -> vector<8xf16> -// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint, -// CHECK-SAME: l2_hint = #xegpu.cache_hint}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16> +// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> gpu.module @test { gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> - %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> - xegpu.store %3, %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> + xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> gpu.return } } @@ -344,18 +340,37 @@ gpu.module @test { // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) { // CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> // CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> -// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint, -// CHECK-SAME: l2_hint = #xegpu.cache_hint}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16> -// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint, -// CHECK-SAME: l2_hint = #xegpu.cache_hint}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16> +// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> gpu.module @test { gpu.func @scatter_ops(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> - %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> - xegpu.store %3, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> + xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) { +// CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> +// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16> +// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, 
memref<256xf16>, vector<1xindex>, vector<1xi1> +gpu.module @test { + gpu.func @scatter_ops(%src: memref<256xf16>) { + %1 = arith.constant dense<1>: vector<16xi1> + %offset = arith.constant dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 { + layout_operand_1 = #xegpu.layout, + layout_operand_2 = #xegpu.layout + } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> + xegpu.store %3, %src[%offset], %1 { + layout_operand_0 = #xegpu.layout, + layout_operand_2 = #xegpu.layout, + layout_operand_3 = #xegpu.layout + } : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> gpu.return } } From bcc9d85b2c953eaf8fa6dd588a8f2aaafa6cd406 Mon Sep 17 00:00:00 2001 From: Artem Kroviakov Date: Fri, 29 Aug 2025 16:26:33 +0000 Subject: [PATCH 5/8] Address feedback --- .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 21 +++++---- .../Transforms/XeGPUSubgroupDistribute.cpp | 44 +++++++------------ .../Dialect/XeGPU/subgroup-distribute.mlir | 27 ++++-------- 3 files changed, 39 insertions(+), 53 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 46c5777d1c157..c0c4394f73d4a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -195,7 +195,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) { /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, - bool scattered = false) { + bool isScattered = false) { // Expecting a 1D or 2D vector. assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) && "Expected 1D or 2D vector."); @@ -208,7 +208,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, // Packing factor is determined by the element type bitwidth. int packingFactor = 1; unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); - if (scattered) { + if (isScattered) { packingFactor = bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth @@ -224,7 +224,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, - bool scattered = false) { + bool isScattered = false) { // Expecting a 1D or 2D vector. assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) && "Expected 1D or 2D TensorDesc."); @@ -237,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, // Packing factor is determined by the element type bitwidth. unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); - if (scattered) { + if (isScattered) { int packingFactor = bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth @@ -558,7 +558,10 @@ void LayoutInfoPropagation::visitLoadGatherOp( ArrayRef results) { // The layout is strictly determined by the payload type. auto payloadTy = dyn_cast(load.getValueType()); - assert(payloadTy && "Only vector payload distribution is supported"); + if (!payloadTy) { + load.emitWarning("Not propagating, non-vector payload supplied."); + return; + } LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true); // Mask operand should have 1D default layout. 
@@ -569,9 +572,8 @@ void LayoutInfoPropagation::visitLoadGatherOp( propagateIfChanged(operands[0], operands[0]->meet(layout)); // Propagate the new layout to the mask and optional offset operand. propagateIfChanged(operands[1], operands[1]->meet(maskLayout)); - if (load.getOffsets()) { + if (load.getOffsets()) propagateIfChanged(operands[2], operands[2]->meet(maskLayout)); - } } /// Propagate the layout of the descriptor to the vector offset operand in @@ -597,7 +599,10 @@ void LayoutInfoPropagation::visitStoreScatterOp( // the tensor descriptor is equal to the subgroup size. This is ensured by // the op verifier. auto payloadTy = dyn_cast(storeScatter.getValueType()); - assert(payloadTy && "Only vector payload distribution is supported"); + if (!payloadTy) { + storeScatter.emitWarning("Not propagating, non-vector payload supplied."); + return; + } auto payloadShape = payloadTy.getShape(); if (payloadShape.size() > 1) assert( diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 9bb0a2160f82e..7b9c8ff3e6f6f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -849,18 +849,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { return rewriter.notifyMatchFailure(storeScatterOp, "Expected 1D offsets and mask vector"); VectorType storeVecTy = cast(storeScatterOp.getValueType()); - assert(storeVecTy.getRank() <= 2 && - "Expected at most 2D result at SG level"); - VectorType distStoreVecTy; - if (storeVecTy.getRank() == 2) - distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0); - else // rank 1 - distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1); - // Assume offset and mask producers will be distributed as well. 
- VectorType distOffsetsTy = - VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); - VectorType distMaskTy = VectorType::get( - {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType())); + if (storeVecTy.getRank() > 2) + return rewriter.notifyMatchFailure( + storeScatterOp, "Expected at most 2D result at SG level"); + std::string layoutPayloadName = xegpu::getLayoutName(storeScatterOp->getOpOperand(0)); std::string layoutOffsetsName = @@ -884,17 +876,20 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { if (failed(distStoreVecByWarpOpOrFailure) || failed(distOffsetsByWarpOpOrFailure) || failed(distMaskByWarpOpOrFailure)) { - storeScatterOp.emitWarning( + return rewriter.notifyMatchFailure( + storeScatterOp, "Some vector operands have no layouts, using defaults instead."); } - distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy); - distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy); - distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy); + VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value(); + VectorType expectedPayloadTy = VectorType::get( + {distPayloadTy.getNumElements()}, distPayloadTy.getElementType()); SmallVector newRetIndices; SmallVector operands = storeScatterOp->getOperands(); SmallVector operandTypesToYield = { - distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy}; + expectedPayloadTy, operands[1].getType(), + distOffsetsByWarpOpOrFailure.value(), + distMaskByWarpOpOrFailure.value()}; gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, operands, operandTypesToYield, newRetIndices); @@ -958,10 +953,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern { return rewriter.notifyMatchFailure(loadGatherOp, "Expected 1D offsets and mask vector"); // Assume offset and mask producers will be distributed as well. 
- VectorType distOffsetsTy = - VectorType::get({1}, getElementTypeOrSelf(offsetsTy)); - VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy)); - std::string layoutOffsetsName = xegpu::getLayoutName(loadGatherOp->getOpOperand(1)); std::string layoutMaskName = @@ -978,16 +969,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern { getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy); if (failed(distOffsetsByWarpOpOrFailure) || failed(distMaskByWarpOpOrFailure)) { - loadGatherOp.emitWarning( + return rewriter.notifyMatchFailure( + loadGatherOp, "Some vector operands have no layouts, using defaults instead."); } - distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy); - distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy); SmallVector newRetIndices; SmallVector operands = loadGatherOp->getOperands(); - SmallVector operandTypesToYield = {operands[0].getType(), - distOffsetsTy, distMaskTy}; + SmallVector operandTypesToYield = { + operands[0].getType(), distOffsetsByWarpOpOrFailure.value(), + distMaskByWarpOpOrFailure.value()}; gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, operands, operandTypesToYield, newRetIndices); @@ -998,7 +989,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern { const unsigned operandIdx = producedByLastLoad->getOperandNumber(); VectorType loadVecTy = cast(warpOp.getResult(operandIdx).getType()); - assert(loadVecTy.getRank() == 1 && "Expected a distributed vector"); rewriter.setInsertionPointAfter(newWarpOp); xegpu::LoadGatherOp newOp = rewriter.create( diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index ddb279eb070ff..5a4030ce4bead 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -330,24 +330,15 @@ gpu.module @test { gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> - %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> - xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> - gpu.return - } -} - -// ----- -// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) { -// CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> -// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> -// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16> -// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> -gpu.module @test { - gpu.func @scatter_ops(%src: memref<256xf16>) { - %1 = arith.constant dense<1>: vector<16xi1> - %offset = arith.constant dense<12> : vector<16xindex> - %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> - xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> { + layout_operand_1 = #xegpu.layout, + layout_operand_2 = #xegpu.layout + } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> + xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> { + layout_operand_0 = #xegpu.layout, + layout_operand_2 = #xegpu.layout, + 
+      layout_operand_3 = #xegpu.layout
+    } : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }

From ffde76c8e1871bea2b1dda9960237d556f8182d4 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov
Date: Mon, 1 Sep 2025 15:01:28 +0000
Subject: [PATCH 6/8] Remove exceptions

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7b9c8ff3e6f6f..b4919932f1ce4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1032,9 +1032,6 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
-      // Vectors operands of these ops have a fixed and implicit layout.
-      if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
-        continue;
       auto layout =
           xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
       if (!layout) {

From 6bafb05a129f674b685615b674fb45c12192c0de Mon Sep 17 00:00:00 2001
From: Artem Kroviakov
Date: Mon, 1 Sep 2025 15:22:18 +0000
Subject: [PATCH 7/8] Restructure testing

---
 .../Dialect/XeGPU/subgroup-distribute.mlir | 26 ++++++-------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 5a4030ce4bead..a39aa90bbe3a8 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -328,17 +328,12 @@ gpu.module @test {
 // CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
   gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
-    %1 = arith.constant dense<1>: vector<16xi1>
-    %offset = arith.constant dense<12> : vector<16xindex>
+    %1 = arith.constant {layout_result_0 = #xegpu.layout} dense<1>: vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<12> : vector<16xindex>
     %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
-      layout_operand_1 = #xegpu.layout,
-      layout_operand_2 = #xegpu.layout
+      layout_result_0 = #xegpu.layout
     } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> {
-      layout_operand_0 = #xegpu.layout,
-      layout_operand_2 = #xegpu.layout,
-      layout_operand_3 = #xegpu.layout
-    } : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }
@@ -351,17 +346,12 @@ gpu.module @test {
 // CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
   gpu.func @scatter_ops(%src: memref<256xf16>) {
-    %1 = arith.constant dense<1>: vector<16xi1>
-    %offset = arith.constant dense<12> : vector<16xindex>
+    %1 = arith.constant {layout_result_0 = #xegpu.layout} dense<1>: vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<12> : vector<16xindex>
     %3 = xegpu.load %src[%offset], %1 {
-      layout_operand_1 = #xegpu.layout,
-      layout_operand_2 = #xegpu.layout
+      layout_result_0 = #xegpu.layout
     } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-    xegpu.store %3, %src[%offset], %1 {
-      layout_operand_0 = #xegpu.layout,
-      layout_operand_2 = #xegpu.layout,
-      layout_operand_3 = #xegpu.layout
-    } : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }

From 2a6c7d678cfbfbfaebde7f98ceb42dc4c2d9c294 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov
Date: Thu, 4 Sep 2025 16:33:17 +0000
Subject: [PATCH 8/8] Query warpOp results before moveRegion

---
 .../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b4919932f1ce4..6b8367dd8c201 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -980,16 +980,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
         distMaskByWarpOpOrFailure.value()};
 
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
 
     SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
-    VectorType loadVecTy =
-        cast<VectorType>(warpOp.getResult(operandIdx).getType());
-
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
         newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,