@@ -22,38 +22,39 @@ namespace mlir {
2222namespace gpu {
2323struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
2424 using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
25+
2526 virtual LogicalResult
2627 matchAndRewrite (WarpExecuteOnLane0Op op,
2728 PatternRewriter &rewriter) const override = 0 ;
2829
2930protected:
3031 // / Return a value yielded by `warpOp` which statifies the filter lamdba
3132 // / condition and is not dead.
32- static OpOperand *getWarpResult (WarpExecuteOnLane0Op warpOp,
33- const std::function<bool (Operation *)> &fn);
33+ OpOperand *getWarpResult (WarpExecuteOnLane0Op warpOp,
34+ const std::function<bool (Operation *)> &fn) const ;
3435
3536 // / Helper to create a new WarpExecuteOnLane0Op with different signature.
36- static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns (
37+ WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns (
3738 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
38- ValueRange newYieldedValues, TypeRange newReturnTypes);
39+ ValueRange newYieldedValues, TypeRange newReturnTypes) const ;
3940
4041 // / Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
4142 // / `indices` return the index of each new output.
42- static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns (
43+ WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns (
4344 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
4445 ValueRange newYieldedValues, TypeRange newReturnTypes,
45- llvm::SmallVector<size_t > &indices);
46+ llvm::SmallVector<size_t > &indices) const ;
4647
4748 // / Delinearize the given `laneId` into multiple dimensions, where each
4849 // / dimension's size is determined by `originalShape` and `distributedShape`
4950 // / together. This function expects the total numbers of threads needed for
5051 // / distribution is equal to `warpSize`. Returns true and updates
5152 // / `delinearizedIds` if so.
52- static bool delinearizeLaneId (OpBuilder &builder, Location loc,
53- ArrayRef<int64_t > originalShape,
54- ArrayRef<int64_t > distributedShape,
55- int64_t warpSize, Value laneId,
56- SmallVectorImpl<Value> &delinearizedIds);
53+ bool delinearizeLaneId (OpBuilder &builder, Location loc,
54+ ArrayRef<int64_t > originalShape,
55+ ArrayRef<int64_t > distributedShape, int64_t warpSize ,
56+ Value laneId,
57+ SmallVectorImpl<Value> &delinearizedIds) const ;
5758};
5859
5960} // namespace gpu
0 commit comments