-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][XeGPU] Distribute non-splat constant from wg to sg #161416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
6712150
7d3746a
1b00dc7
512478b
1381174
e77eddd
1b779b7
1b8db0e
fabb419
2c81dee
29d3f45
ff1bb3b
77afdcc
bbb427a
e10d769
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { | |
| ConversionPatternRewriter &rewriter) const override { | ||
| auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue()); | ||
| auto vecType = dyn_cast<VectorType>(op.getType()); | ||
| if (!vecAttr || !vecAttr.isSplat() || !vecType) | ||
| if (!vecAttr || !vecType) | ||
| return failure(); | ||
|
|
||
| xegpu::DistributeLayoutAttr layout = | ||
|
|
@@ -733,22 +733,172 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { | |
| int count; | ||
| std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); | ||
|
|
||
| // Current limitation: constant of vector with single value. | ||
| // TODO: support more complex cases, e.g., vector with multiple values. | ||
| Attribute singleVal = vecAttr.getSplatValue<Attribute>(); | ||
|
|
||
| auto newType = VectorType::get(sgShape, vecType.getElementType()); | ||
| auto sgAttr = DenseElementsAttr::get(newType, singleVal); | ||
| auto cstOp = | ||
| arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); | ||
| if (!layout.getEffectiveLaneLayoutAsInt().empty() || | ||
| !layout.getEffectiveInstDataAsInt().empty()) | ||
| xegpu::setDistributeLayoutAttr(cstOp->getResult(0), | ||
| layout.dropSgLayoutAndData()); | ||
| SmallVector<Value> newConsts(count, cstOp); | ||
| Location loc = op.getLoc(); | ||
| auto eltType = vecType.getElementType(); | ||
|
|
||
| rewriter.replaceOpWithMultiple(op, {newConsts}); | ||
| return success(); | ||
| auto setLayoutIfNeeded = [&](Value val) { | ||
| if (!layout.getEffectiveLaneLayoutAsInt().empty() || | ||
| !layout.getEffectiveInstDataAsInt().empty()) { | ||
| xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), | ||
| layout.dropSgLayoutAndData()); | ||
| } | ||
| }; | ||
|
|
||
| if (vecAttr.isSplat()) { | ||
| // Splat: single value for all subgroups | ||
| Attribute singleVal = vecAttr.getSplatValue<Attribute>(); | ||
| auto sgAttr = DenseElementsAttr::get(newType, singleVal); | ||
| auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); | ||
| setLayoutIfNeeded(cstOp->getResult(0)); | ||
| rewriter.replaceOp(op, cstOp); | ||
| return success(); | ||
| } else if (sgShape == wgShape) { // if the entire vector is shared by all | ||
| // subgroups, don't distribute | ||
| auto newConstOp = | ||
| arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); | ||
| setLayoutIfNeeded(newConstOp->getResult(0)); | ||
| rewriter.replaceOp(op, newConstOp); | ||
| return success(); | ||
| } else { | ||
| // Non-splat constant | ||
| // Only supports 1D & 2D | ||
| // TODO: support other cases that require SLM access | ||
| if (!eltType.isIndex()) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Unsupported element type for non-splat constant op."); | ||
|
|
||
| if (wgShape.size() > 2) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Only 1D & 2D vector constant supported"); | ||
|
|
||
| SmallVector<Attribute> values(vecAttr.getValues<Attribute>()); | ||
| int64_t stride = 0; | ||
| int64_t rowStride = 0, colStride = 0; | ||
| if (wgShape.size() == 1) { | ||
| // 1D case: single stride | ||
| if (values.size() > 1) { | ||
| stride = cast<IntegerAttr>(values[1]).getInt() - | ||
| cast<IntegerAttr>(values[0]).getInt(); | ||
| for (size_t i = 2; i < values.size(); ++i) { | ||
| int64_t diff = cast<IntegerAttr>(values[i]).getInt() - | ||
| cast<IntegerAttr>(values[i - 1]).getInt(); | ||
| if (diff != stride) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Non-constant stride in non-splat constant op."); | ||
| } | ||
| } | ||
| } else if (wgShape.size() == 2) { | ||
| // 2D case: row stride and column stride | ||
| int64_t rows = wgShape[0], cols = wgShape[1]; | ||
| // Compute col stride (difference between adjacent columns, i.e. between consecutive elements within the same row) | ||
| if (cols > 1) { | ||
| colStride = cast<IntegerAttr>(values[1]).getInt() - | ||
| cast<IntegerAttr>(values[0]).getInt(); | ||
| for (int64_t r = 0; r < rows; ++r) { | ||
| for (int64_t c = 1; c < cols; ++c) { | ||
| int64_t idx = r * cols + c; | ||
| int64_t prevIdx = r * cols + (c - 1); | ||
| int64_t diff = cast<IntegerAttr>(values[idx]).getInt() - | ||
| cast<IntegerAttr>(values[prevIdx]).getInt(); | ||
| if (diff != colStride) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Non-constant column stride in 2D constant op."); | ||
| } | ||
| } | ||
| } | ||
| // Compute row stride (difference between adjacent rows, i.e. between vertically adjacent elements within the same column) | ||
| if (rows > 1) { | ||
| rowStride = cast<IntegerAttr>(values[cols]).getInt() - | ||
| cast<IntegerAttr>(values[0]).getInt(); | ||
| for (int64_t c = 0; c < cols; ++c) { | ||
| for (int64_t r = 1; r < rows; ++r) { | ||
| int64_t idx = r * cols + c; | ||
| int64_t prevIdx = (r - 1) * cols + c; | ||
| int64_t diff = cast<IntegerAttr>(values[idx]).getInt() - | ||
| cast<IntegerAttr>(values[prevIdx]).getInt(); | ||
| if (diff != rowStride) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Non-constant row stride in 2D constant op."); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Determine the shape of the base tile for each subgroup. | ||
| SmallVector<int64_t> baseTileShape; | ||
|
||
| if (sgShape.size() == 1) { | ||
| baseTileShape.push_back(sgShape[0]); | ||
| } else if (sgShape.size() == 2) { | ||
| baseTileShape = sgShape; | ||
| } else { | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Only 1D or 2D vector constant supported"); | ||
| } | ||
|
|
||
| // Create a constant for the base tile. | ||
| // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. | ||
| SmallVector<Attribute> baseTileValues; | ||
| if (baseTileShape.size() == 2) { | ||
| int64_t rows = baseTileShape[0], cols = baseTileShape[1]; | ||
| int64_t wgCols = wgShape[1]; | ||
| for (int64_t r = 0; r < rows; ++r) { | ||
| for (int64_t c = 0; c < cols; ++c) { | ||
| baseTileValues.push_back(values[r * wgCols + c]); | ||
| } | ||
| } | ||
| } else { | ||
| // 1D case | ||
| for (int64_t i = 0; i < computeProduct(baseTileShape); ++i) | ||
| baseTileValues.push_back(values[i]); | ||
| } | ||
| auto tileAttr = DenseElementsAttr::get( | ||
| VectorType::get(baseTileShape, eltType), baseTileValues); | ||
| auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr); | ||
|
|
||
| // Get subgroup id | ||
| Value sgId = | ||
| gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); | ||
|
|
||
| auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); | ||
| if (failed(sgOffsets)) | ||
| return failure(); | ||
|
|
||
| auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride); | ||
| auto rowStrideConst = | ||
| rewriter.create<arith::ConstantIndexOp>(loc, rowStride); | ||
| auto colStrideConst = | ||
| rewriter.create<arith::ConstantIndexOp>(loc, colStride); | ||
| SmallVector<Value> newConstOps; | ||
| for (auto offsets : *sgOffsets) { | ||
| // Multiply offset with stride, broadcast it and add to baseConstVec | ||
| Value mulOffset; | ||
| if (wgShape.size() == 1) { | ||
| // 1D: offset[0] * strideConst | ||
| mulOffset = rewriter.create<arith::MulIOp>( | ||
| loc, rewriter.getIndexType(), offsets[0], strideConst); | ||
| } else if (wgShape.size() == 2) { | ||
|
||
| // 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst | ||
| Value rowMul = rewriter.create<arith::MulIOp>( | ||
| loc, rewriter.getIndexType(), offsets[0], rowStrideConst); | ||
| Value colMul = rewriter.create<arith::MulIOp>( | ||
| loc, rewriter.getIndexType(), offsets[1], colStrideConst); | ||
| mulOffset = rewriter.create<arith::AddIOp>( | ||
| loc, rewriter.getIndexType(), rowMul, colMul); | ||
| } | ||
| // Broadcast to baseConstVec size | ||
| auto bcastOffset = rewriter.create<vector::BroadcastOp>( | ||
| loc, baseConstVec.getType(), mulOffset); | ||
| auto finalConst = | ||
| arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); | ||
| setLayoutIfNeeded(baseConstVec); | ||
| setLayoutIfNeeded(bcastOffset); | ||
| setLayoutIfNeeded(finalConst); | ||
| newConstOps.push_back(finalConst); | ||
| } | ||
| rewriter.replaceOpWithMultiple(op, {newConstOps}); | ||
| return success(); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.