From 671215094dde611017c0f6c16e9665935820d462 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 26 Sep 2025 17:42:27 +0000 Subject: [PATCH 1/9] Add support for non splatable constant --- .../Transforms/XeGPUWgToSgDistribute.cpp | 123 +++++++++++++++--- 1 file changed, 107 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9413a9296b184..8705f4aca0dd1 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -711,7 +711,6 @@ struct UnrealizedConversionCastOpPattern } }; -// This pattern distributes arith.constant op into subgroup-level constants struct WgToSgArithConstantOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -720,7 +719,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto vecAttr = dyn_cast(op.getValue()); auto vecType = dyn_cast(op.getType()); - if (!vecAttr || !vecAttr.isSplat() || !vecType) + if (!vecAttr || !vecType) return failure(); xegpu::DistributeLayoutAttr layout = @@ -733,22 +732,114 @@ struct WgToSgArithConstantOp : public OpConversionPattern { int count; std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); - // Current limitation: constant of vector with single value. - // TODO: support more complex cases, e.g., vector with multiple values. - Attribute singleVal = vecAttr.getSplatValue(); - auto newType = VectorType::get(sgShape, vecType.getElementType()); - auto sgAttr = DenseElementsAttr::get(newType, singleVal); - auto cstOp = - arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(cstOp->getResult(0), - layout.dropSgLayoutAndData()); - SmallVector newConsts(count, cstOp); + Location loc = op.getLoc(); + auto eltType = vecType.getElementType(); - rewriter.replaceOpWithMultiple(op, {newConsts}); - return success(); + auto setLayoutIfNeeded = [&](Value val) { + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast(val), + layout.dropSgLayoutAndData()); + } + }; + + if (vecAttr.isSplat()) { + // Splat: single value for all subgroups + Attribute singleVal = vecAttr.getSplatValue(); + auto sgAttr = DenseElementsAttr::get(newType, singleVal); + auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); + setLayoutIfNeeded(cstOp->getResult(0)); + rewriter.replaceOp(op, cstOp); + return success(); + } else if (sgShape == wgShape) { // if the entire vector is shared by all + // threads...don't distribute + auto newConstOp = + arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); + setLayoutIfNeeded(newConstOp->getResult(0)); + rewriter.replaceOp(op, newConstOp); + return success(); + } else { + // Non-splat constant: use baseValue/stride logic for runtime indexing, + // with wrap-around + if (wgShape.size() >= 2 && wgShape[0] != 1 && wgShape[1] != 1) + return rewriter.notifyMatchFailure( + op, "Only 1D or 2D vector constant supported"); + SmallVector values(vecAttr.getValues()); + int64_t stride = 0; + if (values.size() > 1) { + stride = cast(values[1]).getInt() - + cast(values[0]).getInt(); + for (size_t i = 2; i < values.size(); ++i) { + int64_t diff = cast(values[i]).getInt() - + cast(values[i - 1]).getInt(); + if (diff != stride) + return rewriter.notifyMatchFailure( + op, "Non-constant stride in non-splat constant op."); + } + } + + // Create a constant for the first tile + SmallVector tileValues; + int sgData = 1; + if (sgShape.size() == 1) { + sgData = static_cast(sgShape[0]); + } else if (sgShape.size() == 2) { + // If shape is [1, n] or [n, 1], pick the non-1 dimension (n). + if (sgShape[0] == 1 && sgShape[1] != 1) + sgData = static_cast(sgShape[1]); + else + sgData = static_cast(sgShape[0]); + } else { + return rewriter.notifyMatchFailure( + op, "Only 1D or 2D vector constant supported"); + } + + for (int i = 0; i < sgData; ++i) + tileValues.push_back(values[i]); + auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType), + tileValues); + auto baseConstVec = rewriter.create(loc, tileAttr); + + // Get subgroup/thread id + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + + // Compute baseValue: baseValue = (sgId % numTiles) * stride * sgData + int64_t nonUnitDim = 0; + if (wgShape.size() == 2) + nonUnitDim = wgShape[0] != 1 ? 0 : 1; + // For 1D, just use the first dim + int64_t numTiles = wgShape[nonUnitDim] / sgShape[nonUnitDim]; + auto numTileConst = + rewriter.create(loc, numTiles); + Value remsiOp = rewriter.create( + loc, rewriter.getIndexType(), sgId, numTileConst); + auto baseValueConst = + rewriter.create(loc, stride * sgData); + Value baseValue = rewriter.create( + loc, rewriter.getIndexType(), remsiOp, baseValueConst); + + // Broadcast baseValue to vector + auto splatBaseValue = rewriter.create( + loc, VectorType::get({sgData}, rewriter.getIndexType()), baseValue); + + // Add baseValue to baseConstantVec constant + Value finalTile = rewriter.create( + loc, splatBaseValue->getResult(0), baseConstVec); + + // Cast to final type if needed + Value result; + if (eltType.isIndex()) { + result = finalTile; + } else { + result = rewriter.create(loc, newType, finalTile); + } + + setLayoutIfNeeded(result); + rewriter.replaceOp(op, result); + return success(); + } } }; From 7d3746a09bc77c997382cb137c3a8b2326d28c6f Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 30 Sep 2025 06:00:50 +0000 Subject: [PATCH 2/9] Support 1:N conversion --- .../Transforms/XeGPUWgToSgDistribute.cpp | 72 +++++++++---------- .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 28 ++++++++ .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 20 ++++++ 3 files changed, 80 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 8705f4aca0dd1..2bbf1a85bb5be 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -753,18 +753,22 @@ struct WgToSgArithConstantOp : public OpConversionPattern { rewriter.replaceOp(op, cstOp); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all - // threads...don't distribute + // subgroups...don't distribute auto newConstOp = arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); setLayoutIfNeeded(newConstOp->getResult(0)); rewriter.replaceOp(op, newConstOp); return success(); } else { - // Non-splat constant: use baseValue/stride logic for runtime indexing, - // with wrap-around - if (wgShape.size() >= 2 && wgShape[0] != 1 && wgShape[1] != 1) + // Non-splat constant + if (wgShape.size() > 2) return rewriter.notifyMatchFailure( - op, "Only 1D or 2D vector constant supported"); + op, "Only 1D & 2D vector constant supported"); + + if (wgShape.size() == 2 && wgShape[0] != 1 && wgShape[1] != 1) + return rewriter.notifyMatchFailure( + op, "2D vector constant only supported with 1 unit dim"); + SmallVector values(vecAttr.getValues()); int64_t stride = 0; if (values.size() > 1) { @@ -779,13 +783,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } } - // Create a constant for the first tile + // Create a constant for the base tile SmallVector tileValues; int sgData = 1; if (sgShape.size() == 1) { sgData = static_cast(sgShape[0]); } else if (sgShape.size() == 2) { - // If shape is [1, n] or [n, 1], pick the non-1 dimension (n). + // If shape is [1, n] or [n, 1], pick the non-unit dimension. if (sgShape[0] == 1 && sgShape[1] != 1) sgData = static_cast(sgShape[1]); else @@ -801,43 +805,31 @@ struct WgToSgArithConstantOp : public OpConversionPattern { tileValues); auto baseConstVec = rewriter.create(loc, tileAttr); - // Get subgroup/thread id + // Get subgroup id Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - // Compute baseValue: baseValue = (sgId % numTiles) * stride * sgData - int64_t nonUnitDim = 0; - if (wgShape.size() == 2) - nonUnitDim = wgShape[0] != 1 ? 0 : 1; - // For 1D, just use the first dim - int64_t numTiles = wgShape[nonUnitDim] / sgShape[nonUnitDim]; - auto numTileConst = - rewriter.create(loc, numTiles); - Value remsiOp = rewriter.create( - loc, rewriter.getIndexType(), sgId, numTileConst); - auto baseValueConst = - rewriter.create(loc, stride * sgData); - Value baseValue = rewriter.create( - loc, rewriter.getIndexType(), remsiOp, baseValueConst); - - // Broadcast baseValue to vector - auto splatBaseValue = rewriter.create( - loc, VectorType::get({sgData}, rewriter.getIndexType()), baseValue); - - // Add baseValue to baseConstantVec constant - Value finalTile = rewriter.create( - loc, splatBaseValue->getResult(0), baseConstVec); - - // Cast to final type if needed - Value result; - if (eltType.isIndex()) { - result = finalTile; - } else { - result = rewriter.create(loc, newType, finalTile); + auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector newConstOps; + for (auto offsets : *sgOffsets) { + // Multiply offset with stride and broadcast it to a vector of + // "sgData[nonUnitDim]" size + auto strideConst = rewriter.create(loc, stride); + Value mulOffset = rewriter.create( + loc, rewriter.getIndexType(), offsets[0], strideConst); + auto bcastOffset = rewriter.create( + loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset); + auto finalConst = + arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); + setLayoutIfNeeded(baseConstVec); + setLayoutIfNeeded(bcastOffset); + setLayoutIfNeeded(finalConst); + newConstOps.push_back(finalConst); } - - setLayoutIfNeeded(result); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithMultiple(op, {newConstOps}); return success(); } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index dce73dee507e1..271d2b2f908fb 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -98,4 +98,32 @@ gpu.module @test_distribution { : vector<256x64xf32> to vector<256xf32> gpu.return } + + gpu.func @non_splat_constant() { + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex> + // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]] + // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]] + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]] + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index + // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index + // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index + // CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]] + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index + // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index + // CHECK-DAG: %[[C16_1:.*]] = arith.constant 16 : index + // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index + // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]] + // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index + // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex> + // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_1]] : index + // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex> + // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex> + %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 48fc633974e63..07b1e0f9ba8db 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -464,4 +464,24 @@ gpu.module @test_distribution { %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout} : index to vector<4x2x6x32xindex> gpu.return } + + // CHECK-LABEL: non_splat_constant + gpu.func @non_splat_constant() { + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]] + // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]] + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index + // CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index + // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index + // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]] + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index + // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex> + %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> + gpu.return + } } From 1b00dc76ebca1a475d1345387fd3edef0c34b659 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 30 Sep 2025 16:49:26 +0000 Subject: [PATCH 3/9] All cases work --- .../Transforms/XeGPUWgToSgDistribute.cpp | 43 ++++++++++++------- .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 3 +- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 2bbf1a85bb5be..be03e6e050c43 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -711,6 +711,7 @@ struct UnrealizedConversionCastOpPattern } }; +// This pattern distributes arith.constant op into subgroup-level constants struct WgToSgArithConstantOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -753,7 +754,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern { rewriter.replaceOp(op, cstOp); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all - // subgroups...don't distribute + // subgroups, don't distribute auto newConstOp = arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); setLayoutIfNeeded(newConstOp->getResult(0)); @@ -761,13 +762,28 @@ struct WgToSgArithConstantOp : public OpConversionPattern { return success(); } else { // Non-splat constant + // Only supports 1D & 2D (with one unit dim) + // TODO: support other cases that require SLM access + if (!eltType.isIndex()) + return rewriter.notifyMatchFailure( + op, "Unsupported element type for non-splat constant op."); + + SmallVector sgLayout = layout.getEffectiveSgLayoutAsInt(); if (wgShape.size() > 2) return rewriter.notifyMatchFailure( op, "Only 1D & 2D vector constant supported"); - if (wgShape.size() == 2 && wgShape[0] != 1 && wgShape[1] != 1) + // allow 2D vector/distributions with one unit dim + auto hasTwoNonUnitDims = [](ArrayRef dims) { + return dims.size() == 2 && dims[0] != 1 && dims[1] != 1; + }; + if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout)) return rewriter.notifyMatchFailure( - op, "2D vector constant only supported with 1 unit dim"); + op, "2D vector/distribution only supported with 1 unit dim"); + + int64_t nonUnitDim = 0; + if (wgShape.size() == 2) + nonUnitDim = wgShape[0] != 1 ? 0 : 1; SmallVector values(vecAttr.getValues()); int64_t stride = 0; @@ -783,26 +799,22 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } } - // Create a constant for the base tile - SmallVector tileValues; int sgData = 1; if (sgShape.size() == 1) { sgData = static_cast(sgShape[0]); } else if (sgShape.size() == 2) { - // If shape is [1, n] or [n, 1], pick the non-unit dimension. - if (sgShape[0] == 1 && sgShape[1] != 1) - sgData = static_cast(sgShape[1]); - else - sgData = static_cast(sgShape[0]); + sgData = static_cast(sgShape[0] != 1 ? sgShape[0] : sgShape[1]); } else { return rewriter.notifyMatchFailure( op, "Only 1D or 2D vector constant supported"); } + // Create a constant for the base tile + SmallVector baseTileValues; for (int i = 0; i < sgData; ++i) - tileValues.push_back(values[i]); + baseTileValues.push_back(values[i]); auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType), - tileValues); + baseTileValues); auto baseConstVec = rewriter.create(loc, tileAttr); // Get subgroup id @@ -813,13 +825,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern { if (failed(sgOffsets)) return failure(); + auto strideConst = rewriter.create(loc, stride); SmallVector newConstOps; for (auto offsets : *sgOffsets) { - // Multiply offset with stride and broadcast it to a vector of - // "sgData[nonUnitDim]" size - auto strideConst = rewriter.create(loc, stride); + // Multiply offset with stride, broadcast it and add to baseConstVec Value mulOffset = rewriter.create( - loc, rewriter.getIndexType(), offsets[0], strideConst); + loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst); auto bcastOffset = rewriter.create( loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset); auto finalConst = diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 271d2b2f908fb..f3e2e41ae4b65 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -115,12 +115,11 @@ gpu.module @test_distribution { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index - // CHECK-DAG: %[[C16_1:.*]] = arith.constant 16 : index // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]] // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex> - // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_1]] : index + // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex> // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex> %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> From 1381174b29812c6db58c46038aba2b718d9c9072 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 30 Sep 2025 20:33:23 +0000 Subject: [PATCH 4/9] Fix CHECKS --- .../Transforms/XeGPUWgToSgDistribute.cpp | 2 +- .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 34 ++++++++++++------- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 13 ++++--- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index be03e6e050c43..9807cb98a5a83 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -831,7 +831,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern { // Multiply offset with stride, broadcast it and add to baseConstVec Value mulOffset = rewriter.create( loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst); - auto bcastOffset = rewriter.create( + auto bcastOffset = rewriter.create( loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index f3e2e41ae4b65..9958d4ef6c1e2 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -102,26 +102,34 @@ gpu.module @test_distribution { gpu.func @non_splat_constant() { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex> // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]] // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]] // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]] + // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index - // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index + // CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index - // CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]] - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[REM:.*]] = index.remu %[[MUL]], %[[C32]] + // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[REM2:.*]] = index.remu %[[AFF2]], %[[C1_3]] // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index - // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index - // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index - // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]] - // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index - // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex> - // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index - // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex> - // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex> + // CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C16]] : index + // CHECK-DAG: %[[C32_5:.*]] = arith.constant 32 : index + // CHECK-DAG: %[[REM3:.*]] = index.remu %[[ADD]], %[[C32_5]] + // CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[REM4:.*]] = index.remu %[[AFF2]], %[[C1_6]] + // CHECK-DAG: %[[C16_7:.*]] = arith.constant 16 : index + // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_7]] : index + // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL2]] : index to vector<2xindex> + // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<2xindex> + // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM3]], %[[C16_7]] : index + // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[MUL3]] : index to vector<2xindex> + // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[CST]], %[[BCAST2]] : vector<2xindex> %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 85b78bb41db08..a2203f8e945d2 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -463,18 +463,17 @@ gpu.module @test_distribution { gpu.func @non_splat_constant() { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]] // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]] // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index - // CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index - // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index - // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]] - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[IDY]], %[[C32]] + // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[IDX]], %[[C1]] // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index - // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex> + // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex> + // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex> %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } From e77edddfc6ae346d0537f238c74b1b7524ec163e Mon Sep 17 00:00:00 2001 From: nbpatel Date: Sun, 5 Oct 2025 21:55:42 +0000 Subject: [PATCH 5/9] Support 2D case --- .../Transforms/XeGPUWgToSgDistribute.cpp | 129 +++++++++++++----- .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 50 +++---- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 22 ++- 3 files changed, 127 insertions(+), 74 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9807cb98a5a83..b7107011ee178 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -773,48 +773,96 @@ struct WgToSgArithConstantOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Only 1D & 2D vector constant supported"); - // allow 2D vector/distributions with one unit dim - auto hasTwoNonUnitDims = [](ArrayRef dims) { - return dims.size() == 2 && dims[0] != 1 && dims[1] != 1; - }; - if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout)) - return rewriter.notifyMatchFailure( - op, "2D vector/distribution only supported with 1 unit dim"); - - int64_t nonUnitDim = 0; - if (wgShape.size() == 2) - nonUnitDim = wgShape[0] != 1 ? 0 : 1; - SmallVector values(vecAttr.getValues()); int64_t stride = 0; - if (values.size() > 1) { - stride = cast(values[1]).getInt() - - cast(values[0]).getInt(); - for (size_t i = 2; i < values.size(); ++i) { - int64_t diff = cast(values[i]).getInt() - - cast(values[i - 1]).getInt(); - if (diff != stride) - return rewriter.notifyMatchFailure( - op, "Non-constant stride in non-splat constant op."); + int64_t rowStride = 0, colStride = 0; + if (wgShape.size() == 1) { + // 1D case: single stride + if (values.size() > 1) { + stride = cast(values[1]).getInt() - + cast(values[0]).getInt(); + for (size_t i = 2; i < values.size(); ++i) { + int64_t diff = cast(values[i]).getInt() - + cast(values[i - 1]).getInt(); + if (diff != stride) + return rewriter.notifyMatchFailure( + op, "Non-constant stride in non-splat constant op."); + } + } + } else if (wgShape.size() == 2) { + // 2D case: row stride and column stride + int64_t rows = wgShape[0], cols = wgShape[1]; + if (values.size() != static_cast(rows * cols)) + return rewriter.notifyMatchFailure( + op, "Mismatch between vector shape and constant values size."); + // Compute col stride (stride between elements in a column) + if (cols > 1) { + colStride = cast(values[1]).getInt() - + cast(values[0]).getInt(); + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 1; c < cols; ++c) { + int64_t idx = r * cols + c; + int64_t prevIdx = r * cols + (c - 1); + int64_t diff = cast(values[idx]).getInt() - + cast(values[prevIdx]).getInt(); + if (diff != colStride) + return rewriter.notifyMatchFailure( + op, "Non-constant column stride in 2D constant op."); + } + } + } + // Compute row stride (stride between elements in a row) + if (rows > 1) { + rowStride = cast(values[cols]).getInt() - + cast(values[0]).getInt(); + for (int64_t c = 0; c < cols; ++c) { + for (int64_t r = 1; r < rows; ++r) { + int64_t idx = r * cols + c; + int64_t prevIdx = (r - 1) * cols + c; + int64_t diff = cast(values[idx]).getInt() - + cast(values[prevIdx]).getInt(); + if (diff != rowStride) + return rewriter.notifyMatchFailure( + op, "Non-constant row stride in 2D constant op."); + } + } } } - int sgData = 1; + // Determine the shape of the base tile for each subgroup. + SmallVector baseTileShape; if (sgShape.size() == 1) { - sgData = static_cast(sgShape[0]); + baseTileShape.push_back(sgShape[0]); } else if (sgShape.size() == 2) { - sgData = static_cast(sgShape[0] != 1 ? sgShape[0] : sgShape[1]); + baseTileShape = sgShape; } else { return rewriter.notifyMatchFailure( op, "Only 1D or 2D vector constant supported"); } - // Create a constant for the base tile + // Compute the number of elements in the base tile. + int64_t baseTileElemCount = 1; + for (int64_t d : baseTileShape) + baseTileElemCount *= d; + + // Create a constant for the base tile. + // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. SmallVector baseTileValues; - for (int i = 0; i < sgData; ++i) - baseTileValues.push_back(values[i]); - auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType), - baseTileValues); + if (baseTileShape.size() == 2) { + int64_t rows = baseTileShape[0], cols = baseTileShape[1]; + int64_t wgRows = wgShape[0], wgCols = wgShape[1]; + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + baseTileValues.push_back(values[r * wgCols + c]); + } + } + } else { + // 1D case + for (int64_t i = 0; i < baseTileElemCount; ++i) + baseTileValues.push_back(values[i]); + } + auto tileAttr = DenseElementsAttr::get( + VectorType::get(baseTileShape, eltType), baseTileValues); auto baseConstVec = rewriter.create(loc, tileAttr); // Get subgroup id @@ -826,13 +874,30 @@ struct WgToSgArithConstantOp : public OpConversionPattern { return failure(); auto strideConst = rewriter.create(loc, stride); + auto strideConstRow = + rewriter.create(loc, rowStride); + auto strideConstCol = + rewriter.create(loc, colStride); SmallVector newConstOps; for (auto offsets : *sgOffsets) { // Multiply offset with stride, broadcast it and add to baseConstVec - Value mulOffset = rewriter.create( - loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst); + Value mulOffset; + if (baseTileShape.size() == 1) { + // 1D: offset[0] * strideConst + mulOffset = rewriter.create( + loc, rewriter.getIndexType(), offsets[0], strideConst); + } else if (baseTileShape.size() == 2) { + // 2D: offset[0]*strideConstRow + offset[1]*strideConstCol + Value rowMul = rewriter.create( + loc, rewriter.getIndexType(), offsets[0], strideConstRow); + Value colMul = rewriter.create( + loc, rewriter.getIndexType(), offsets[1], strideConstCol); + mulOffset = rewriter.create( + loc, rewriter.getIndexType(), rowMul, colMul); + } + // Broadcast to baseConstVec size auto bcastOffset = rewriter.create( - loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset); + loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); setLayoutIfNeeded(baseConstVec); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 9958d4ef6c1e2..c2e51bdb71485 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -100,36 +100,26 @@ gpu.module @test_distribution { } gpu.func @non_splat_constant() { - // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex> - // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]] - // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]] - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]] - // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index - // CHECK-DAG: %[[REM:.*]] = index.remu %[[MUL]], %[[C32]] - // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[REM2:.*]] = index.remu %[[AFF2]], %[[C1_3]] - // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index - // CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C16]] : index - // CHECK-DAG: %[[C32_5:.*]] = arith.constant 32 : index - // CHECK-DAG: %[[REM3:.*]] = index.remu %[[ADD]], %[[C32_5]] - // CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[REM4:.*]] = index.remu %[[AFF2]], %[[C1_6]] - // CHECK-DAG: %[[C16_7:.*]] = arith.constant 16 : index - // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_7]] : index - // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL2]] : index to vector<2xindex> - // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<2xindex> - // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM3]], %[[C16_7]] : index - // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[MUL3]] : index to vector<2xindex> - // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[CST]], %[[BCAST2]] : vector<2xindex> + // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]] + // CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]] + // CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]] + // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]] + // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]] + // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index + // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]] + // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]] + // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index + // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index + // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[STRIDE1]], %[[STRIDE2]] : index + // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex> + // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex> + // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index + // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index + // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[STRIDE3]], %[[STRIDE4]] : index + // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES2]] : index to vector<2x1xindex> + // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex> %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index a2203f8e945d2..51158fa11a9ec 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -461,19 +461,17 @@ gpu.module @test_distribution { // CHECK-LABEL: non_splat_constant gpu.func @non_splat_constant() { - // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]] - // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]] - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[IDY]], %[[C32]] - // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[IDX]], %[[C1]] - // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index - // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index - // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex> - // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex> + // CHECK-DAG: affine.apply #map4()[%[[SGID]]] + // CHECK-DAG: affine.apply #map5()[%[[SGID]]] + // CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]] + // CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]] + // CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index + // CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index + // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[STRIDECOL]], %[[STRIDEROW]] : index + // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex> + // CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex> %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } From 1b779b7d39cfb3650f0eec1dd6c5b7bace1dd4a9 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 6 Oct 2025 05:43:07 +0000 Subject: [PATCH 6/9] Add 2D test case --- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 51158fa11a9ec..5f990a49f1298 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -475,4 +475,34 @@ gpu.module @test_distribution { %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } + + // CHECK-LABEL: non_splat_constant_2D_non_unit_dim + gpu.func @non_splat_constant_2D_non_unit_dim() { + // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]] + // CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]] + // CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]] + // CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]] + // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]] + // CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index + // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]] + // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index + // CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index + // CHECK-DAG: %[[ADDIDX:.*]] = arith.addi %[[MUL5]], %[[MUL6]] : index + // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDIDX]] : index to vector<2x2xindex> + // CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex> + %cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout} dense<[ + [0, 16, 32, 48, 64, 80, 96, 112], + [8, 24, 40, 56, 72, 88, 104, 120], + [16, 32, 48, 64, 80, 96, 112, 128], + [24, 40, 56, 72, 88, 104, 120, 136], + [32, 48, 64, 80, 96, 112, 128, 144], + [40, 56, 72, 88, 104, 120, 136, 152], + [48, 64, 80, 96, 112, 128, 144, 160], + [56, 72, 88, 104, 120, 136, 152, 168] + ]> : vector<8x8xindex> + gpu.return + } } From 1b8db0e54a0a6d5d3e1258f3e9793ab2065e03fe Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 6 Oct 2025 05:47:56 +0000 Subject: [PATCH 7/9] Clean up --- .../Transforms/XeGPUWgToSgDistribute.cpp | 29 +++++++------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index b7107011ee178..2862400c85cca 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -762,13 +762,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern { return success(); } else { // Non-splat constant - // Only supports 1D & 2D (with one unit dim) + // Only supports 1D & 2D // TODO: support other cases that require SLM access if (!eltType.isIndex()) return rewriter.notifyMatchFailure( op, "Unsupported element type for non-splat constant op."); - SmallVector sgLayout = layout.getEffectiveSgLayoutAsInt(); if (wgShape.size() > 2) return rewriter.notifyMatchFailure( op, "Only 1D & 2D vector constant supported"); @@ -792,9 +791,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } else if (wgShape.size() == 2) { // 2D case: row stride and column stride int64_t rows = wgShape[0], cols = wgShape[1]; - if (values.size() != static_cast(rows * cols)) - return rewriter.notifyMatchFailure( - op, "Mismatch between vector shape and constant values size."); // Compute col stride (stride between elements in a column) if (cols > 1) { colStride = cast(values[1]).getInt() - @@ -840,17 +836,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern { op, "Only 1D or 2D vector constant supported"); } - // Compute the number of elements in the base tile. - int64_t baseTileElemCount = 1; - for (int64_t d : baseTileShape) - baseTileElemCount *= d; - // Create a constant for the base tile. // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. SmallVector baseTileValues; if (baseTileShape.size() == 2) { int64_t rows = baseTileShape[0], cols = baseTileShape[1]; - int64_t wgRows = wgShape[0], wgCols = wgShape[1]; + int64_t wgCols = wgShape[1]; for (int64_t r = 0; r < rows; ++r) { for (int64_t c = 0; c < cols; ++c) { baseTileValues.push_back(values[r * wgCols + c]); @@ -858,7 +849,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } } else { // 1D case - for (int64_t i = 0; i < baseTileElemCount; ++i) + for (int64_t i = 0; i < computeProduct(baseTileShape); ++i) baseTileValues.push_back(values[i]); } auto tileAttr = DenseElementsAttr::get( @@ -874,24 +865,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern { return failure(); auto strideConst = rewriter.create(loc, stride); - auto strideConstRow = + auto rowStrideConst = rewriter.create(loc, rowStride); - auto strideConstCol = + auto colStrideConst = rewriter.create(loc, colStride); SmallVector newConstOps; for (auto offsets : *sgOffsets) { // Multiply offset with stride, broadcast it and add to baseConstVec Value mulOffset; - if (baseTileShape.size() == 1) { + if (wgShape.size() == 1) { // 1D: offset[0] * strideConst mulOffset = rewriter.create( loc, rewriter.getIndexType(), offsets[0], strideConst); - } else if (baseTileShape.size() == 2) { - // 2D: offset[0]*strideConstRow + offset[1]*strideConstCol + } else if (wgShape.size() == 2) { + // 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst Value rowMul = rewriter.create( - loc, rewriter.getIndexType(), offsets[0], strideConstRow); + loc, rewriter.getIndexType(), offsets[0], rowStrideConst); Value colMul = rewriter.create( - loc, rewriter.getIndexType(), offsets[1], strideConstCol); + loc, rewriter.getIndexType(), offsets[1], colStrideConst); mulOffset = rewriter.create( loc, rewriter.getIndexType(), rowMul, colMul); } From fabb41919b3ac7a24e2193e01c80f93d2933636e Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 6 Oct 2025 22:44:27 +0000 Subject: [PATCH 8/9] Clean up --- .../Transforms/XeGPUWgToSgDistribute.cpp | 65 +++++++------------ 1 file changed, 24 insertions(+), 41 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 2862400c85cca..dd9f50967534a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -825,35 +825,21 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } } - // Determine the shape of the base tile for each subgroup. - SmallVector baseTileShape; - if (sgShape.size() == 1) { - baseTileShape.push_back(sgShape[0]); - } else if (sgShape.size() == 2) { - baseTileShape = sgShape; - } else { - return rewriter.notifyMatchFailure( - op, "Only 1D or 2D vector constant supported"); - } - // Create a constant for the base tile. // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. + // For 1D case, extract the first sgShape[0] elements. SmallVector baseTileValues; - if (baseTileShape.size() == 2) { - int64_t rows = baseTileShape[0], cols = baseTileShape[1]; - int64_t wgCols = wgShape[1]; - for (int64_t r = 0; r < rows; ++r) { - for (int64_t c = 0; c < cols; ++c) { - baseTileValues.push_back(values[r * wgCols + c]); - } + int cols = sgShape[sgShape.size() - 1]; + int64_t wgCols = wgShape[sgShape.size() - 1]; + int64_t rows = sgShape.size() == 1 ? 1 : sgShape[0]; + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + baseTileValues.push_back(values[r * wgCols + c]); } - } else { - // 1D case - for (int64_t i = 0; i < computeProduct(baseTileShape); ++i) - baseTileValues.push_back(values[i]); } - auto tileAttr = DenseElementsAttr::get( - VectorType::get(baseTileShape, eltType), baseTileValues); + + auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), + baseTileValues); auto baseConstVec = rewriter.create(loc, tileAttr); // Get subgroup id @@ -864,27 +850,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern { if (failed(sgOffsets)) return failure(); - auto strideConst = rewriter.create(loc, stride); - auto rowStrideConst = - rewriter.create(loc, rowStride); - auto colStrideConst = - rewriter.create(loc, colStride); + SmallVector strideConsts; + strideConsts.push_back( + rewriter.create(loc, rowStride)); + strideConsts.push_back( + rewriter.create(loc, colStride)); SmallVector newConstOps; + Value mulOffset; for (auto offsets : *sgOffsets) { // Multiply offset with stride, broadcast it and add to baseConstVec - Value mulOffset; - if (wgShape.size() == 1) { - // 1D: offset[0] * strideConst - mulOffset = rewriter.create( - loc, rewriter.getIndexType(), offsets[0], strideConst); - } else if (wgShape.size() == 2) { - // 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst - Value rowMul = rewriter.create( - loc, rewriter.getIndexType(), offsets[0], rowStrideConst); - Value colMul = rewriter.create( - loc, rewriter.getIndexType(), offsets[1], colStrideConst); + SmallVector muls; + for (size_t i = 0; i < strideConsts.size(); ++i) { + muls.push_back(rewriter.create( + loc, rewriter.getIndexType(), offsets[i], strideConsts[i])); + } + mulOffset = muls.front(); + if (muls.size() > 1) { mulOffset = rewriter.create( - loc, rewriter.getIndexType(), rowMul, colMul); + loc, rewriter.getIndexType(), mulOffset, muls[1]); } // Broadcast to baseConstVec size auto bcastOffset = rewriter.create( From 2c81deeb011b1f3e4e7a97c731f715c0b4b6d9f8 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 7 Oct 2025 00:32:18 +0000 Subject: [PATCH 9/9] Refactor --- .../Transforms/XeGPUWgToSgDistribute.cpp | 94 ++++++++----------- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 4 +- 2 files changed, 43 insertions(+), 55 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index dd9f50967534a..659039b41638d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -773,54 +773,40 @@ struct WgToSgArithConstantOp : public OpConversionPattern { op, "Only 1D & 2D vector constant supported"); SmallVector values(vecAttr.getValues()); - int64_t stride = 0; int64_t rowStride = 0, colStride = 0; - if (wgShape.size() == 1) { - // 1D case: single stride - if (values.size() > 1) { - stride = cast(values[1]).getInt() - - cast(values[0]).getInt(); - for (size_t i = 2; i < values.size(); ++i) { - int64_t diff = cast(values[i]).getInt() - - cast(values[i - 1]).getInt(); - if (diff != stride) + int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0]; + int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1]; + + // Compute colStride and rowStride, and check for constant strides. + if (cols > 1) { + colStride = cast(values[1]).getInt() - + cast(values[0]).getInt(); + } + if (rows > 1) { + rowStride = cast(values[cols]).getInt() - + cast(values[0]).getInt(); + } + + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + int64_t idx = r * cols + c; + // Check column stride (skip first column) + if (c > 0 && cols > 1) { + int64_t prevIdx = r * cols + (c - 1); + int64_t diff = cast(values[idx]).getInt() - + cast(values[prevIdx]).getInt(); + if (diff != colStride) return rewriter.notifyMatchFailure( - op, "Non-constant stride in non-splat constant op."); - } - } - } else if (wgShape.size() == 2) { - // 2D case: row stride and column stride - int64_t rows = wgShape[0], cols = wgShape[1]; - // Compute col stride (stride between elements in a column) - if (cols > 1) { - colStride = cast(values[1]).getInt() - - cast(values[0]).getInt(); - for (int64_t r = 0; r < rows; ++r) { - for (int64_t c = 1; c < cols; ++c) { - int64_t idx = r * cols + c; - int64_t prevIdx = r * cols + (c - 1); - int64_t diff = cast(values[idx]).getInt() - - cast(values[prevIdx]).getInt(); - if (diff != colStride) - return rewriter.notifyMatchFailure( - op, "Non-constant column stride in 2D constant op."); - } + op, "Non-constant column stride in constant op."); } - } - // Compute row stride (stride between elements in a row) - if (rows > 1) { - rowStride = cast(values[cols]).getInt() - - cast(values[0]).getInt(); - for (int64_t c = 0; c < cols; ++c) { - for (int64_t r = 1; r < rows; ++r) { - int64_t idx = r * cols + c; - int64_t prevIdx = (r - 1) * cols + c; - int64_t diff = cast(values[idx]).getInt() - - cast(values[prevIdx]).getInt(); - if (diff != rowStride) - return rewriter.notifyMatchFailure( - op, "Non-constant row stride in 2D constant op."); - } + // Check row stride (skip first row) + if (r > 0 && rows > 1) { + int64_t prevIdx = (r - 1) * cols + c; + int64_t diff = cast(values[idx]).getInt() - + cast(values[prevIdx]).getInt(); + if (diff != rowStride) + return rewriter.notifyMatchFailure( + op, "Non-constant row stride in constant op."); } } } @@ -829,12 +815,11 @@ struct WgToSgArithConstantOp : public OpConversionPattern { // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. // For 1D case, extract the first sgShape[0] elements. SmallVector baseTileValues; - int cols = sgShape[sgShape.size() - 1]; - int64_t wgCols = wgShape[sgShape.size() - 1]; - int64_t rows = sgShape.size() == 1 ? 1 : sgShape[0]; - for (int64_t r = 0; r < rows; ++r) { - for (int64_t c = 0; c < cols; ++c) { - baseTileValues.push_back(values[r * wgCols + c]); + int baseTileCols = sgShape[sgShape.size() - 1]; + int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0]; + for (int64_t r = 0; r < baseTileRows; ++r) { + for (int64_t c = 0; c < baseTileCols; ++c) { + baseTileValues.push_back(values[r * cols + c]); } } @@ -851,10 +836,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern { return failure(); SmallVector strideConsts; - strideConsts.push_back( - rewriter.create(loc, rowStride)); strideConsts.push_back( rewriter.create(loc, colStride)); + if (rows > 1) + strideConsts.insert( + strideConsts.begin(), + rewriter.create(loc, rowStride)); + SmallVector newConstOps; Value mulOffset; for (auto offsets : *sgOffsets) { diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 5f990a49f1298..676c96db69236 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -459,8 +459,8 @@ gpu.module @test_distribution { gpu.return } - // CHECK-LABEL: non_splat_constant - gpu.func @non_splat_constant() { + // CHECK-LABEL: non_splat_constant_2D + gpu.func @non_splat_constant_2D() { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index // CHECK-DAG: affine.apply #map4()[%[[SGID]]]