diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9413a9296b184..659039b41638d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto vecAttr = dyn_cast(op.getValue()); auto vecType = dyn_cast(op.getType()); - if (!vecAttr || !vecAttr.isSplat() || !vecType) + if (!vecAttr || !vecType) return failure(); xegpu::DistributeLayoutAttr layout = @@ -733,22 +733,143 @@ struct WgToSgArithConstantOp : public OpConversionPattern { int count; std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); - // Current limitation: constant of vector with single value. - // TODO: support more complex cases, e.g., vector with multiple values. - Attribute singleVal = vecAttr.getSplatValue(); - auto newType = VectorType::get(sgShape, vecType.getElementType()); - auto sgAttr = DenseElementsAttr::get(newType, singleVal); - auto cstOp = - arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(cstOp->getResult(0), - layout.dropSgLayoutAndData()); - SmallVector newConsts(count, cstOp); + Location loc = op.getLoc(); + auto eltType = vecType.getElementType(); - rewriter.replaceOpWithMultiple(op, {newConsts}); - return success(); + auto setLayoutIfNeeded = [&](Value val) { + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast(val), + layout.dropSgLayoutAndData()); + } + }; + + if (vecAttr.isSplat()) { + // Splat: single value for all subgroups + Attribute singleVal = vecAttr.getSplatValue(); + auto sgAttr = DenseElementsAttr::get(newType, singleVal); + auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); + setLayoutIfNeeded(cstOp->getResult(0)); + rewriter.replaceOp(op, cstOp); + return success(); + } else if (sgShape == wgShape) { // if the entire vector is shared by all + // subgroups, don't distribute + auto newConstOp = + arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); + setLayoutIfNeeded(newConstOp->getResult(0)); + rewriter.replaceOp(op, newConstOp); + return success(); + } else { + // Non-splat constant + // Only supports 1D & 2D + // TODO: support other cases that require SLM access + if (!eltType.isIndex()) + return rewriter.notifyMatchFailure( + op, "Unsupported element type for non-splat constant op."); + + if (wgShape.size() > 2) + return rewriter.notifyMatchFailure( + op, "Only 1D & 2D vector constant supported"); + + SmallVector values(vecAttr.getValues()); + int64_t rowStride = 0, colStride = 0; + int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0]; + int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1]; + + // Compute colStride and rowStride, and check for constant strides. + if (cols > 1) { + colStride = cast(values[1]).getInt() - + cast(values[0]).getInt(); + } + if (rows > 1) { + rowStride = cast(values[cols]).getInt() - + cast(values[0]).getInt(); + } + + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + int64_t idx = r * cols + c; + // Check column stride (skip first column) + if (c > 0 && cols > 1) { + int64_t prevIdx = r * cols + (c - 1); + int64_t diff = cast(values[idx]).getInt() - + cast(values[prevIdx]).getInt(); + if (diff != colStride) + return rewriter.notifyMatchFailure( + op, "Non-constant column stride in constant op."); + } + // Check row stride (skip first row) + if (r > 0 && rows > 1) { + int64_t prevIdx = (r - 1) * cols + c; + int64_t diff = cast(values[idx]).getInt() - + cast(values[prevIdx]).getInt(); + if (diff != rowStride) + return rewriter.notifyMatchFailure( + op, "Non-constant row stride in constant op."); + } + } + } + + // Create a constant for the base tile. + // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. + // For 1D case, extract the first sgShape[0] elements. + SmallVector baseTileValues; + int baseTileCols = sgShape[sgShape.size() - 1]; + int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0]; + for (int64_t r = 0; r < baseTileRows; ++r) { + for (int64_t c = 0; c < baseTileCols; ++c) { + baseTileValues.push_back(values[r * cols + c]); + } + } + + auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), + baseTileValues); + auto baseConstVec = rewriter.create(loc, tileAttr); + + // Get subgroup id + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + + auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector strideConsts; + strideConsts.push_back( + rewriter.create(loc, colStride)); + if (rows > 1) + strideConsts.insert( + strideConsts.begin(), + rewriter.create(loc, rowStride)); + + SmallVector newConstOps; + Value mulOffset; + for (auto offsets : *sgOffsets) { + // Multiply offset with stride, broadcast it and add to baseConstVec + SmallVector muls; + for (size_t i = 0; i < strideConsts.size(); ++i) { + muls.push_back(rewriter.create( + loc, rewriter.getIndexType(), offsets[i], strideConsts[i])); + } + mulOffset = muls.front(); + if (muls.size() > 1) { + mulOffset = rewriter.create( + loc, rewriter.getIndexType(), mulOffset, muls[1]); + } + // Broadcast to baseConstVec size + auto bcastOffset = rewriter.create( + loc, baseConstVec.getType(), mulOffset); + auto finalConst = + arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); + setLayoutIfNeeded(baseConstVec); + setLayoutIfNeeded(bcastOffset); + setLayoutIfNeeded(finalConst); + newConstOps.push_back(finalConst); + } + rewriter.replaceOpWithMultiple(op, {newConstOps}); + return success(); + } } }; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index dce73dee507e1..c2e51bdb71485 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -98,4 +98,29 @@ gpu.module @test_distribution { : vector<256x64xf32> to vector<256xf32> gpu.return } + + gpu.func @non_splat_constant() { + // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]] + // CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]] + // CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]] + // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]] + // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]] + // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index + // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]] + // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]] + // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index + // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index + // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[STRIDE1]], %[[STRIDE2]] : index + // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex> + // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex> + // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index + // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index + // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[STRIDE3]], %[[STRIDE4]] : index + // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES2]] : index to vector<2x1xindex> + // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex> + %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 03c63861705d9..676c96db69236 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -458,4 +458,51 @@ gpu.module @test_distribution { %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout} : index to vector<4x2x6x32xindex> gpu.return } + + // CHECK-LABEL: non_splat_constant_2D + gpu.func @non_splat_constant_2D() { + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: affine.apply #map4()[%[[SGID]]] + // CHECK-DAG: affine.apply #map5()[%[[SGID]]] + // CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]] + // CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]] + // CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index + // CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index + // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[STRIDECOL]], %[[STRIDEROW]] : index + // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex> + // CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex> + %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> + gpu.return + } + + // CHECK-LABEL: non_splat_constant_2D_non_unit_dim + gpu.func @non_splat_constant_2D_non_unit_dim() { + // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]] + // CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]] + // CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]] + // CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]] + // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]] + // CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index + // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]] + // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index + // CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index + // CHECK-DAG: %[[ADDIDX:.*]] = arith.addi %[[MUL5]], %[[MUL6]] : index + // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDIDX]] : index to vector<2x2xindex> + // CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex> + %cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout} dense<[ + [0, 16, 32, 48, 64, 80, 96, 112], + [8, 24, 40, 56, 72, 88, 104, 120], + [16, 32, 48, 64, 80, 96, 112, 128], + [24, 40, 56, 72, 88, 104, 120, 136], + [32, 48, 64, 80, 96, 112, 128, 144], + [40, 56, 72, 88, 104, 120, 136, 152], + [48, 64, 80, 96, 112, 128, 144, 160], + [56, 72, 88, 104, 120, 136, 152, 168] + ]> : vector<8x8xindex> + gpu.return + } }