Skip to content

Commit 40cf489

Browse files
committed
[MLIR][SCF] Add support for vectorization hints in scf-to-cf lowering and provide an option to control it.
1 parent 71cf592 commit 40cf489

File tree

4 files changed

+48
-12
lines changed

4 files changed

+48
-12
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,12 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
984984
// Pass definition for `convert-scf-to-cf`: lowers SCF structured control flow
// to a ControlFlow-dialect CFG (see the summary below).
def SCFToControlFlowPass : Pass<"convert-scf-to-cf"> {
985985
let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
986986
" control flow with a CFG";
987-
let dependentDialects = ["cf::ControlFlowDialect"];
987+
// The LLVM dialect is listed as well because the scf.parallel lowering in
// this commit builds LLVM::LoopVectorizeAttr / LLVM::LoopAnnotationAttr.
let dependentDialects = ["cf::ControlFlowDialect", "LLVM::LLVMDialect"];
988+
989+
// Opt-in flag (default "false"); the pass forwards its value to
// populateSCFToControlFlowConversionPatterns.
// NOTE(review): "Hits" in the member and flag names looks like a typo for
// "Hints"; it is the public spelling used by the tests, so confirm before
// renaming.
let options = [Option<"enableVectorizeHits", "enable-vectorize-hits", "bool",
990+
/*default=*/"false",
991+
"Add vectorization hints when convert SCF parallel "
992+
"loop to ControlFlow dialect">];
988993
}
989994

990995
//===----------------------------------------------------------------------===//

mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ class RewritePatternSet;
2020

2121
/// Collect a set of patterns to convert SCF operations to CFG branch-based
2222
/// operations within the ControlFlow dialect.
23-
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);
23+
/// `enableVectorizeHits` (default false) is handed to the scf.parallel
/// lowering pattern; when true, that pattern attaches an LLVM loop annotation
/// with a vectorize entry to the scf.for loops it creates.
void populateSCFToControlFlowConversionPatterns(
24+
RewritePatternSet &patterns, bool enableVectorizeHits = false);
2425

2526
} // namespace mlir
2627

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace {
3838

3939
/// Pass implementation for `convert-scf-to-cf`, derived from
/// impl::SCFToControlFlowPassBase.
struct SCFToControlFlowPass
4040
: public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
41+
// Inherit the base-class constructors.
// NOTE(review): presumably required so the pass can be constructed with the
// newly added option - confirm against the generated base class.
using Base::Base;
4142
// Builds the SCF-to-ControlFlow pattern set (passing the current value of the
// enableVectorizeHits option) and configures the conversion target.
void runOnOperation() override;
4243
};
4344

@@ -212,6 +213,11 @@ struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
212213
/// Rewrite pattern for scf.parallel. Per the implementation, the parallel loop
/// is translated into a nest of scf.for ops that are lowered by a further
/// rewrite.
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
213214
// The base-class constructors remain available; the default member
// initializer below keeps the flag well-defined when one of them is used.
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
214215

216+
/// When true, the scf.for ops created for the parallel loop are tagged with
/// an LLVM loop annotation whose vectorize entry has `disable = false`.
/// Defaults to false (matching the pass option default) so that constructing
/// the pattern through an inherited constructor never reads an indeterminate
/// value.
bool enableVectorizeHits = false;
217+
218+
/// Constructs the pattern with an explicit choice for the vectorization hint.
ParallelLowering(mlir::MLIRContext *ctx, bool enableVectorizeHits)
219+
: OpRewritePattern(ctx), enableVectorizeHits(enableVectorizeHits) {}
220+
215221
LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
216222
PatternRewriter &rewriter) const override;
217223
};
@@ -487,6 +493,13 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
487493
return failure();
488494
}
489495

496+
auto vecAttr = LLVM::LoopVectorizeAttr::get(
497+
rewriter.getContext(),
498+
/* disable */ rewriter.getBoolAttr(false), {}, {}, {}, {}, {}, {});
499+
auto loopAnnotation = LLVM::LoopAnnotationAttr::get(
500+
rewriter.getContext(), {}, /*vectorize=*/vecAttr, {}, {}, {}, {}, {}, {},
501+
{}, {}, {}, {}, {}, {}, {});
502+
490503
// For a parallel loop, we essentially need to create an n-dimensional loop
491504
// nest. We do this by translating to scf.for ops and have those lowered in
492505
// a further rewrite. If a parallel loop contains reductions (and thus returns
@@ -517,6 +530,11 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
517530
rewriter.create<scf::YieldOp>(loc, forOp.getResults());
518531
}
519532

533+
if (enableVectorizeHits)
534+
forOp->setAttr(LLVM::BrOp::getLoopAnnotationAttrName(OperationName(
535+
LLVM::BrOp::getOperationName(), getContext())),
536+
loopAnnotation);
537+
520538
rewriter.setInsertionPointToStart(forOp.getBody());
521539
}
522540

@@ -706,16 +724,18 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
706724
}
707725

708726
// Registers the SCF-to-ControlFlow lowering patterns in `patterns`, using the
// context owned by the pattern set. (Diff view: the old and new forms of the
// signature and of the first `add` call are both shown below.)
void mlir::populateSCFToControlFlowConversionPatterns(
709-
RewritePatternSet &patterns) {
710-
patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
711-
WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
727+
RewritePatternSet &patterns, bool enableVectorizeHits) {
728+
patterns.add<ForallLowering, ForLowering, IfLowering, WhileLowering,
729+
ExecuteRegionLowering, IndexSwitchLowering>(
712730
patterns.getContext());
731+
// ParallelLowering is added separately because its constructor takes the
// extra enableVectorizeHits flag.
patterns.add<ParallelLowering>(patterns.getContext(), enableVectorizeHits);
713732
// DoWhileLowering is registered with an explicit pattern benefit of 2.
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
714733
}
715734

716735
void SCFToControlFlowPass::runOnOperation() {
717736
RewritePatternSet patterns(&getContext());
718-
populateSCFToControlFlowConversionPatterns(patterns);
737+
populateSCFToControlFlowConversionPatterns(patterns,
738+
enableVectorizeHits.getValue());
719739

720740
// Configure conversion to lower out SCF operations.
721741
ConversionTarget target(getContext());

mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s --check-prefixes=CHECK,NO-VEC
2+
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="enable-vectorize-hits=true" \
3+
// RUN: -split-input-file %s | FileCheck %s --check-prefixes=CHECK,VEC
4+
5+
// VEC: #loop_vectorize = #llvm.loop_vectorize<disable = false>
6+
// VEC-NEXT: #[[$VEC_ATTR:.+]] = #llvm.loop_annotation<vectorize = #loop_vectorize>
27

38
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
49
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
@@ -332,7 +337,8 @@ func.func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
332337
// variable and the current partially reduced value.
333338
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
334339
// CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
335-
// CHECK: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
340+
// NO-VEC: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
341+
// VEC: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] {loop_annotation = #[[$VEC_ATTR]]}
336342

337343
// Bodies of scf.reduce operations are folded into the main loop body. The
338344
// result of this partial reduction is passed as argument to the condition
@@ -366,11 +372,13 @@ func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
366372
// CHECK: %[[INIT2:.*]] = arith.constant 42
367373
// CHECK: cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
368374
// CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
369-
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
375+
// NO-VEC: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
376+
// VEC: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] {loop_annotation = #[[$VEC_ATTR]]}
370377
// CHECK: ^[[BODY_OUT]]:
371378
// CHECK: cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
372379
// CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
373-
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
380+
// NO-VEC: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
381+
// VEC: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] {loop_annotation = #[[$VEC_ATTR]]}
374382
// CHECK: ^[[BODY_IN]]:
375383
// CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
376384
// CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
@@ -551,7 +559,8 @@ func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1,
551559
// CHECK: cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
552560
// CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
553561
// CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
554-
// CHECK: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
562+
// NO-VEC: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
563+
// VEC: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] {loop_annotation = #[[$VEC_ATTR]]}
555564
// CHECK: ^[[LOOP_BODY]]:
556565
// CHECK: cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
557566
// CHECK: ^[[IF1_THEN]]:
@@ -660,7 +669,8 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
660669
// CHECK: cf.br ^[[bb1:.*]](%[[c0]] : index)
661670
// CHECK: ^[[bb1]](%[[arg0:.*]]: index):
662671
// CHECK: %[[cmpi:.*]] = arith.cmpi slt, %[[arg0]], %[[num_threads]]
663-
// CHECK: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
672+
// NO-VEC: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
673+
// VEC: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]] {loop_annotation = #[[$VEC_ATTR]]}
664674
// CHECK: ^[[bb2]]:
665675
// CHECK: "test.foo"(%[[arg0]])
666676
// CHECK: %[[addi:.*]] = arith.addi %[[arg0]], %[[c1]]

0 commit comments

Comments
 (0)