@@ -210,13 +210,21 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
210210struct ScalarSubgroupReduceToShuffles final
211211 : OpRewritePattern<gpu::SubgroupReduceOp> {
212212 ScalarSubgroupReduceToShuffles (MLIRContext *ctx, unsigned subgroupSize,
213- unsigned shuffleBitwidth,
213+ unsigned shuffleBitwidth, bool matchClustered,
214214 PatternBenefit benefit)
215215 : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
216- shuffleBitwidth (shuffleBitwidth) {}
216+ shuffleBitwidth (shuffleBitwidth), matchClustered(matchClustered) {}
217217
218218 LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
219219 PatternRewriter &rewriter) const override {
220+ if (op.getClusterSize ().has_value () != matchClustered) {
221+ return rewriter.notifyMatchFailure (
222+ op, llvm::formatv (" op is {0}clustered but pattern is configured to "
223+ " only match {1}clustered ops" ,
224+ matchClustered ? " non-" : " " ,
225+ matchClustered ? " " : " non-" ));
226+ }
227+
220228 auto ci = getAndValidateClusterInfo (op, subgroupSize);
221229 if (failed (ci))
222230 return failure ();
@@ -262,19 +270,28 @@ struct ScalarSubgroupReduceToShuffles final
262270private:
263271 unsigned subgroupSize = 0 ;
264272 unsigned shuffleBitwidth = 0 ;
273+ bool matchClustered = false ;
265274};
266275
267276// / Lowers vector gpu subgroup reductions to a series of shuffles.
268277struct VectorSubgroupReduceToShuffles final
269278 : OpRewritePattern<gpu::SubgroupReduceOp> {
270279 VectorSubgroupReduceToShuffles (MLIRContext *ctx, unsigned subgroupSize,
271- unsigned shuffleBitwidth,
280+ unsigned shuffleBitwidth, bool matchClustered,
272281 PatternBenefit benefit)
273282 : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
274- shuffleBitwidth (shuffleBitwidth) {}
283+ shuffleBitwidth (shuffleBitwidth), matchClustered(matchClustered) {}
275284
276285 LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
277286 PatternRewriter &rewriter) const override {
287+ if (op.getClusterSize ().has_value () != matchClustered) {
288+ return rewriter.notifyMatchFailure (
289+ op, llvm::formatv (" op is {0}clustered but pattern is configured to "
290+ " only match {1}clustered ops" ,
291+ matchClustered ? " non-" : " " ,
292+ matchClustered ? " " : " non-" ));
293+ }
294+
278295 auto ci = getAndValidateClusterInfo (op, subgroupSize);
279296 if (failed (ci))
280297 return failure ();
@@ -343,6 +360,7 @@ struct VectorSubgroupReduceToShuffles final
343360private:
344361 unsigned subgroupSize = 0 ;
345362 unsigned shuffleBitwidth = 0 ;
363+ bool matchClustered = false ;
346364};
347365} // namespace
348366
@@ -358,5 +376,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
358376 RewritePatternSet &patterns, unsigned subgroupSize,
359377 unsigned shuffleBitwidth, PatternBenefit benefit) {
360378 patterns.add <ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
361- patterns.getContext (), subgroupSize, shuffleBitwidth, benefit);
379+ patterns.getContext (), subgroupSize, shuffleBitwidth,
380+ /* matchClustered=*/ false , benefit);
381+ }
382+
383+ void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns (
384+ RewritePatternSet &patterns, unsigned subgroupSize,
385+ unsigned shuffleBitwidth, PatternBenefit benefit) {
386+ patterns.add <ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
387+ patterns.getContext (), subgroupSize, shuffleBitwidth,
388+ /* matchClustered=*/ true , benefit);
362389}
0 commit comments