From b94a37fb467510f4cfe9e3b902a2147f99010c8f Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 18 Aug 2025 16:18:26 +0000 Subject: [PATCH 1/8] Add pattern for load_gather and store_scatter ops --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 12 +++ mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 34 ++++++ .../Transforms/XeGPUWgToSgDistribute.cpp | 100 +++++++++++++++++- 3 files changed, 145 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 480b43e740736..34429a34f6d96 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -751,6 +751,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { let builders = [ OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)>, + OpBuilder<(ins "Type": $value, "Value": $source, + "ArrayRef": $offsets, "Value": $mask, + "IntegerAttr": $chunk_size, "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, "xegpu::CachePolicyAttr": $l3_hint)> @@ -859,6 +865,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { let builders = [ OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)>, + OpBuilder<(ins "Value": $value, "Value": $dest, + "ArrayRef": $offsets, "Value": $mask, + "IntegerAttr": $chunk_size, "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, "xegpu::CachePolicyAttr": $l3_hint)> diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 7b7ce19e6937b..2e9bb88174c74 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -737,6 +737,22 @@ void 
LoadGatherOp::build(OpBuilder &builder, OperationState &state, l1_hint, l2_hint, l3_hint); } +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, + ArrayRef offsets, Value mask, + IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + auto loc = source.getLoc(); + int64_t size = static_cast(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, + l2_hint, l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_StoreScatterOp //===----------------------------------------------------------------------===// @@ -785,6 +801,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, l2_hint, l3_hint); } +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value dest, + ArrayRef offsets, Value mask, + IntegerAttr chunk_size, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + auto loc = dest.getLoc(); + int64_t size = static_cast(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + // Call the correct builder overload that does not expect result types. 
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, + l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_UpdateOffsetOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 270d71aaa7273..a4d697b5357e6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -685,6 +685,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } }; +// This pattern transforms the LoadGatherOp with explicit offsets to load +// subgroup data, similar to WgToSgLoadNdOpWithOffset. +struct WgToSgLoadGatherOpWithOffset + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getOffsets()) + return failure(); + + Location loc = op.getLoc(); + VectorType resultType = op.getResult().getType(); + ArrayRef wgShape = resultType.getShape(); + + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); + if (!layout || !layout.getSgLayout()) + return failure(); + + SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; + + SmallVector newLoadOps; + auto chunkSizeAttr = rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); + VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); + for (auto [offsets, mask] : + llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + auto newLoadOp = rewriter.create( + loc, newTy, op.getSource(), offsets, mask, + chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + xegpu::setLayoutAttr(newLoadOp->getResult(0), + layout.dropSgLayoutAndData()); + newLoadOps.push_back(newLoadOp); + } + 
rewriter.replaceOpWithMultiple(op, {newLoadOps}); + return success(); + } +}; + +// This pattern transforms the StoreScatterOp with explicit offsets to store +// subgroup data, similar to WgToSgStoreNdOpWithOffset. +struct WgToSgStoreScatterOpWithOffset + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getOffsets()) + return failure(); + + Location loc = op.getLoc(); + VectorType valueType = dyn_cast(op.getValue().getType()); + if (!valueType) + return failure(); + + ArrayRef wgShape = valueType.getShape(); + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue()); + if (!layout || !layout.getSgLayout()) + return failure(); + + auto chunkSizeOpt = op.getChunkSize(); + int64_t chunkSize = chunkSizeOpt ? static_cast(*chunkSizeOpt) : 1; + auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); + for (auto [val, offs, mask] : llvm::zip( + adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { + rewriter.create( + loc, val, op.getDest(), offs, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + // Update the layout_result_0 attribute to drop sg_layout and sg_data. 
+ if (auto layoutAttr = + op->getAttrOfType("layout_result_0")) { + if (auto newLayout = layoutAttr.dropSgLayoutAndData()) + op->setAttr("layout_result_0", newLayout); + } + } + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace namespace mlir { @@ -694,7 +776,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, WgToSgVectorBroadcastOp, - WgToSgConvertLayoutOp, WgToSgArithConstantOp>( + WgToSgConvertLayoutOp, WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, + WgToSgStoreScatterOpWithOffset>( patterns.getContext()); } } // namespace xegpu @@ -815,6 +898,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(xegpu::getLayoutAttr(op.getResult())); }); + target.addDynamicallyLegalOp( + [=](xegpu::LoadGatherOp op) -> bool { + auto layout = xegpu::getLayoutAttr(op.getResult()); + return isLegal(layout); + }); + + target.addDynamicallyLegalOp( + [=](xegpu::StoreScatterOp op) -> bool { + // Check if the layout attribute is present on the result. 
+ auto layout = op->getAttrOfType("layout_result_0"); + if (!layout) + return true; + return isLegal(layout); + }); + target.addDynamicallyLegalOp( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); From 459e98ab8e4f770f4832873d9adc4d97035c9933 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 19 Aug 2025 17:26:10 +0000 Subject: [PATCH 2/8] Add tests --- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 07a0b86223c33..03fd95f0e778e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -263,4 +263,45 @@ gpu.module @test_distribution { } {sg_id_range = #xegpu.range<[3, 19]>} gpu.return } + + // CHECK-LABEL: @load_gather + // CHECK-SAME: %[[ARG0:.*]]: memref + gpu.func @load_gather(%src : memref) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex> + // CHECK: %[[MASK:.*]] = arith.constant dense : vector<32x4xi1> + // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint}> : memref, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16> + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<256x16xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<256x16xi1> + %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout, l1_hint = #xegpu.cache_hint} + : memref, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16> + gpu.return + } + + // CHECK-LABEL: @store_scatter + // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16> + gpu.func @store_scatter(%dest : memref<256xf16>) { + // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16> + // CHECK: %[[CST:.*]] = arith.constant dense<0> 
: vector<8xindex> + // CHECK: %[[MASK:.*]] = arith.constant dense : vector<8xi1> + // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1> + %val = arith.constant {layout_result_0 = #xegpu.layout} dense<25.5> : vector<256xf16> + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<256xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<256xi1> + xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout, l1_hint = #xegpu.cache_hint} + : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1> + gpu.return + } + + // CHECK-LABEL: @load_with_chunk_size + // CHECK-SAME: %[[ARG0:.*]]: memref + gpu.func @load_with_chunk_size(%src : memref) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex> + // CHECK: %[[MASK:.*]] = arith.constant dense : vector<8xi1> + // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint}> : memref, vector<8xindex>, vector<8xi1> -> vector<8x4xf16> + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<256xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<256xi1> + %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout, l1_hint = #xegpu.cache_hint} + : memref, vector<256xindex>, vector<256xi1> -> vector<256x4xf16> + gpu.return + } } From a25c40deb19866819c0a8efa61e30d878111826c Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 28 Aug 2025 16:38:55 +0000 Subject: [PATCH 3/8] Feedback --- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 16 ++++++---------- .../Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 15 +++++++++------ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp 
b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 367d93b37c614..90084b4395355 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -780,7 +780,7 @@ struct WgToSgLoadGatherOpWithOffset ArrayRef wgShape = resultType.getShape(); xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); - if (!layout || !layout.getSgLayout()) + if (!layout || !layout.isForWorkgroup()) return failure(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; @@ -820,9 +820,8 @@ struct WgToSgStoreScatterOpWithOffset if (!valueType) return failure(); - ArrayRef wgShape = valueType.getShape(); xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue()); - if (!layout || !layout.getSgLayout()) + if (!layout || !layout.isForWorkgroup()) return failure(); auto chunkSizeOpt = op.getChunkSize(); @@ -833,12 +832,9 @@ struct WgToSgStoreScatterOpWithOffset rewriter.create( loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); - // Update the layout_result_0 attribute to drop sg_layout and sg_data. - if (auto layoutAttr = - op->getAttrOfType("layout_result_0")) { - if (auto newLayout = layoutAttr.dropSgLayoutAndData()) - op->setAttr("layout_result_0", newLayout); - } + // Update the layout attribute to drop sg_layout and sg_data. + if (auto newLayout = layout.dropSgLayoutAndData()) + op->setAttr("layout", newLayout); } rewriter.eraseOp(op); return success(); @@ -1042,7 +1038,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp( [=](xegpu::StoreScatterOp op) -> bool { // Check if the layout attribute is present on the result. 
- auto layout = op->getAttrOfType("layout_result_0"); + auto layout = op->getAttrOfType("layout"); if (!layout) return true; return isLegal(layout); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 90e36a4c994b9..afb2bf876c18f 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -269,7 +269,8 @@ gpu.module @test_distribution { gpu.func @load_gather(%src : memref) { // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex> // CHECK: %[[MASK:.*]] = arith.constant dense : vector<32x4xi1> - // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint}> : memref, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16> + // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint}> + // CHECK-SAME: : memref, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16> %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<256x16xindex> %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<256x16xi1> %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout, l1_hint = #xegpu.cache_hint} @@ -283,21 +284,23 @@ gpu.module @test_distribution { // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16> // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex> // CHECK: %[[MASK:.*]] = arith.constant dense : vector<8xi1> - // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1> + // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint}> + // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1> %val = 
arith.constant {layout_result_0 = #xegpu.layout} dense<25.5> : vector<256xf16> %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<256xindex> %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<256xi1> - xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout, l1_hint = #xegpu.cache_hint} + xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout, l1_hint = #xegpu.cache_hint} : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1> gpu.return } - // CHECK-LABEL: @load_with_chunk_size + // CHECK-LABEL: @load_with_non_unit_chunk_size // CHECK-SAME: %[[ARG0:.*]]: memref - gpu.func @load_with_chunk_size(%src : memref) { + gpu.func @load_with_non_unit_chunk_size(%src : memref) { // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex> // CHECK: %[[MASK:.*]] = arith.constant dense : vector<8xi1> - // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint}> : memref, vector<8xindex>, vector<8xi1> -> vector<8x4xf16> + // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint}> + // CHECK-SAME: : memref, vector<8xindex>, vector<8xi1> -> vector<8x4xf16> %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<256xindex> %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<256xi1> %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout, l1_hint = #xegpu.cache_hint} From bdbf14f0c1bcbea52863da47a1ef84f1c8024013 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 28 Aug 2025 16:45:09 +0000 Subject: [PATCH 4/8] Cleanup --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp 
b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 90084b4395355..3ed177e398836 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -764,7 +764,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern { }; // This pattern transforms the LoadGatherOp with explicit offsets to load -// subgroup data, similar to WgToSgLoadNdOpWithOffset. +// subgroup data struct WgToSgLoadGatherOpWithOffset : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -804,7 +804,7 @@ struct WgToSgLoadGatherOpWithOffset }; // This pattern transforms the StoreScatterOp with explicit offsets to store -// subgroup data, similar to WgToSgStoreNdOpWithOffset. +// subgroup data struct WgToSgStoreScatterOpWithOffset : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; From c93090f2f0853aa7047d90bc58aba9d1be518a92 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 28 Aug 2025 17:21:24 +0000 Subject: [PATCH 5/8] Add check --- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 3ed177e398836..90848eebd1243 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -785,6 +785,14 @@ struct WgToSgLoadGatherOpWithOffset SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; + // The offsets need to be distributed + if (dyn_cast(adaptor.getOffsets().front().getType()) + .getShape() != + dyn_cast(adaptor.getMask().front().getType()).getShape()) { + return rewriter.notifyMatchFailure(op, + "offsets have not been distributed"); + } + SmallVector newLoadOps; auto chunkSizeAttr = rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); @@ -824,6 +832,14 @@ struct 
WgToSgStoreScatterOpWithOffset if (!layout || !layout.isForWorkgroup()) return failure(); + // The offsets need to be distributed + if (dyn_cast(adaptor.getOffsets().front().getType()) + .getShape() != + dyn_cast(adaptor.getMask().front().getType()).getShape()) { + return rewriter.notifyMatchFailure(op, + "offsets have not been distributed"); + } + auto chunkSizeOpt = op.getChunkSize(); int64_t chunkSize = chunkSizeOpt ? static_cast(*chunkSizeOpt) : 1; auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); From a7b780d172be99f9b3b0672640e0aa27e958bd86 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 29 Aug 2025 17:50:03 +0000 Subject: [PATCH 6/8] Feedback --- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 90848eebd1243..623cac66e90f3 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -786,9 +786,12 @@ struct WgToSgLoadGatherOpWithOffset SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; // The offsets need to be distributed - if (dyn_cast(adaptor.getOffsets().front().getType()) - .getShape() != - dyn_cast(adaptor.getMask().front().getType()).getShape()) { + auto offsetsVecType = + dyn_cast(adaptor.getOffsets().front().getType()); + auto maskVecType = + dyn_cast(adaptor.getMask().front().getType()); + if (!offsetsVecType || !maskVecType || + offsetsVecType.getShape() != maskVecType.getShape()) { return rewriter.notifyMatchFailure(op, "offsets have not been distributed"); } @@ -833,9 +836,12 @@ struct WgToSgStoreScatterOpWithOffset return failure(); // The offsets need to be distributed - if (dyn_cast(adaptor.getOffsets().front().getType()) - .getShape() != - dyn_cast(adaptor.getMask().front().getType()).getShape()) { + auto offsetsVecType = + 
dyn_cast(adaptor.getOffsets().front().getType()); + auto maskVecType = + dyn_cast(adaptor.getMask().front().getType()); + if (!offsetsVecType || !maskVecType || + offsetsVecType.getShape() != maskVecType.getShape()) { return rewriter.notifyMatchFailure(op, "offsets have not been distributed"); } From 21f1f4fbe2e12ea936ff9df391793785a084ff7b Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 29 Aug 2025 17:59:05 +0000 Subject: [PATCH 7/8] Add check --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 623cac66e90f3..02f7bd74b472f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -776,7 +776,9 @@ struct WgToSgLoadGatherOpWithOffset return failure(); Location loc = op.getLoc(); - VectorType resultType = op.getResult().getType(); + VectorType resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return failure(); ArrayRef wgShape = resultType.getShape(); xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); From efc211bc6f69e826530833fa0944142b855d60e4 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Sat, 30 Aug 2025 20:55:59 +0000 Subject: [PATCH 8/8] Use updated APIs for layout attr --- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 74dac59208514..9f627c7e1e6d8 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -783,7 +783,8 @@ struct WgToSgLoadGatherOpWithOffset return failure(); ArrayRef wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = 
xegpu::getLayoutAttr(op.getResult()); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -809,8 +810,8 @@ struct WgToSgLoadGatherOpWithOffset auto newLoadOp = rewriter.create( loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); - xegpu::setLayoutAttr(newLoadOp->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), + layout.dropSgLayoutAndData()); newLoadOps.push_back(newLoadOp); } rewriter.replaceOpWithMultiple(op, {newLoadOps}); @@ -835,7 +836,8 @@ struct WgToSgStoreScatterOpWithOffset if (!valueType) return failure(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue()); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getValue()); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -1057,7 +1059,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp( [=](xegpu::LoadGatherOp op) -> bool { - auto layout = xegpu::getLayoutAttr(op.getResult()); + auto layout = xegpu::getDistributeLayoutAttr(op.getResult()); return isLegal(layout); });