Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,12 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
// Declarative definition of the `convert-scf-to-cf` pass.
// NOTE(review): `dependentDialects` is assigned twice below; this looks like
// the pre-change and post-change lines of the diff shown together. Confirm
// that only the second assignment (which adds "LLVM::LLVMDialect") remains.
def SCFToControlFlowPass : Pass<"convert-scf-to-cf"> {
let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
" control flow with a CFG";
let dependentDialects = ["cf::ControlFlowDialect"];
let dependentDialects = ["cf::ControlFlowDialect", "LLVM::LLVMDialect"];

// Boolean pass option, off by default.
// NOTE(review): "Hits" in the member and flag names is presumably a typo for
// "Hints" (the description text says "vectorization hints"). Renaming would
// also touch the C++ users of this option and the test RUN lines; confirm
// before changing.
let options = [Option<"enableVectorizeHits", "enable-vectorize-hits", "bool",
/*default=*/"false",
"Add vectorization hints when convert SCF parallel "
"loop to ControlFlow dialect">];
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class RewritePatternSet;

/// Collect a set of patterns to convert SCF operations to CFG branch-based
/// operations within the ControlFlow dialect.
///
/// `enableVectorizeHits` (default: false) is forwarded to the scf.parallel
/// lowering pattern; when true, that pattern attaches an LLVM loop-annotation
/// attribute to the scf.for ops it creates.
///
/// NOTE(review): two declarations follow (one-argument and two-argument with a
/// default); they look like the pre-change and post-change lines of the diff
/// shown together. Confirm that only the two-argument form remains.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);
void populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns, bool enableVectorizeHits = false);

} // namespace mlir

Expand Down
28 changes: 24 additions & 4 deletions mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace {

/// Pass implementation for SCF-to-ControlFlow conversion. Derives from
/// `impl::SCFToControlFlowPassBase` (presumably generated from the
/// `SCFToControlFlowPass` definition in Passes.td; confirm).
struct SCFToControlFlowPass
: public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
// Inherit the base-class constructors instead of declaring any here.
using Base::Base;
// Defined out of line; builds the pattern set via
// populateSCFToControlFlowConversionPatterns, passing the value of the
// `enableVectorizeHits` option.
void runOnOperation() override;
};

Expand Down Expand Up @@ -212,6 +213,11 @@ struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
/// Rewrite pattern for `scf.parallel`. The out-of-line `matchAndRewrite`
/// translates the parallel loop into a nest of `scf.for` ops, which are then
/// lowered by a further rewrite.
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
  // The base-class constructors stay available for existing callers.
  using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;

  /// When true, `matchAndRewrite` attaches an LLVM loop-annotation attribute
  /// (vectorization hint) to every `scf.for` it creates.
  ///
  /// The default member initializer matters: an instance built through one of
  /// the inherited constructors never runs the constructor below, so without
  /// it this flag would be read while indeterminate.
  bool enableVectorizeHits = false;

  /// Builds the pattern for `ctx` with an explicit choice for the
  /// vectorization-hint flag.
  ParallelLowering(mlir::MLIRContext *ctx, bool enableVectorizeHits)
      : OpRewritePattern(ctx), enableVectorizeHits(enableVectorizeHits) {}

  LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
                                PatternRewriter &rewriter) const override;
};
Expand Down Expand Up @@ -487,6 +493,13 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
return failure();
}

auto vecAttr = LLVM::LoopVectorizeAttr::get(
rewriter.getContext(),
/* disable */ rewriter.getBoolAttr(false), {}, {}, {}, {}, {}, {});
auto loopAnnotation = LLVM::LoopAnnotationAttr::get(
rewriter.getContext(), {}, /*vectorize=*/vecAttr, {}, {}, {}, {}, {}, {},
{}, {}, {}, {}, {}, {}, {});

// For a parallel loop, we essentially need to create an n-dimensional loop
// nest. We do this by translating to scf.for ops and have those lowered in
// a further rewrite. If a parallel loop contains reductions (and thus returns
Expand Down Expand Up @@ -517,6 +530,11 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
rewriter.create<scf::YieldOp>(loc, forOp.getResults());
}

if (enableVectorizeHits)
forOp->setAttr(LLVM::BrOp::getLoopAnnotationAttrName(OperationName(
LLVM::BrOp::getOperationName(), getContext())),
loopAnnotation);

rewriter.setInsertionPointToStart(forOp.getBody());
}

Expand Down Expand Up @@ -706,16 +724,18 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
}

/// Adds the SCF lowering patterns to `patterns`, using the context returned
/// by `patterns.getContext()`.
// NOTE(review): lines from before and after the change appear interleaved in
// this definition (two parameter lists, two `patterns.add<...>` template
// argument lists). Confirm that only the variant taking `enableVectorizeHits`
// and registering ParallelLowering separately remains.
void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
RewritePatternSet &patterns, bool enableVectorizeHits) {
patterns.add<ForallLowering, ForLowering, IfLowering, WhileLowering,
ExecuteRegionLowering, IndexSwitchLowering>(
patterns.getContext());
// ParallelLowering is registered on its own because it takes the extra
// `enableVectorizeHits` constructor argument.
patterns.add<ParallelLowering>(patterns.getContext(), enableVectorizeHits);
// DoWhileLowering is registered with an explicit benefit of 2, presumably so
// it is preferred over WhileLowering when both match; confirm.
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}

void SCFToControlFlowPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateSCFToControlFlowConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns,
enableVectorizeHits.getValue());

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
Expand Down
22 changes: 16 additions & 6 deletions mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s --check-prefixes=CHECK,NO-VEC
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="enable-vectorize-hits=true" \
// RUN: -split-input-file %s | FileCheck %s --check-prefixes=CHECK,VEC

// VEC: #loop_vectorize = #llvm.loop_vectorize<disable = false>
// VEC-NEXT: #[[$VEC_ATTR:.+]] = #llvm.loop_annotation<vectorize = #loop_vectorize>

// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
Expand Down Expand Up @@ -332,7 +337,8 @@ func.func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
// variable and the current partially reduced value.
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
// CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
// CHECK: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
// NO-VEC: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
// VEC: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] {loop_annotation = #[[$VEC_ATTR]]}

// Bodies of scf.reduce operations are folded into the main loop body. The
// result of this partial reduction is passed as argument to the condition
Expand Down Expand Up @@ -366,11 +372,13 @@ func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: %[[INIT2:.*]] = arith.constant 42
// CHECK: cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
// CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
// NO-VEC: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
// VEC: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// CHECK: ^[[BODY_OUT]]:
// CHECK: cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
// CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
// NO-VEC: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
// VEC: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// CHECK: ^[[BODY_IN]]:
// CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
// CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
Expand Down Expand Up @@ -551,7 +559,8 @@ func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1,
// CHECK: cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
// CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
// CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
// CHECK: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
// NO-VEC: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
// VEC: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// CHECK: ^[[LOOP_BODY]]:
// CHECK: cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
// CHECK: ^[[IF1_THEN]]:
Expand Down Expand Up @@ -660,7 +669,8 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
// CHECK: cf.br ^[[bb1:.*]](%[[c0]] : index)
// CHECK: ^[[bb1]](%[[arg0:.*]]: index):
// CHECK: %[[cmpi:.*]] = arith.cmpi slt, %[[arg0]], %[[num_threads]]
// CHECK: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
// NO-VEC: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
// VEC: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// CHECK: ^[[bb2]]:
// CHECK: "test.foo"(%[[arg0]])
// CHECK: %[[addi:.*]] = arith.addi %[[arg0]], %[[c1]]
Expand Down