Skip to content

Commit 3b83e7f

Browse files
authored
[flang] Implement !DIR$ IVDEP directive (#133728)
This directive tells the compiler to ignore vector dependencies in the following loop and it must be placed before a `do loop`. Sometimes the compiler may not have sufficient information to decide whether a particular loop is vectorizable due to potential dependencies between iterations and the directive is here to tell to the compiler that vectorization is safe with `parallelAccesses` metadata. This directive is also equivalent to `#pragma clang loop assume(safety)` in C++
1 parent 75af8e8 commit 3b83e7f

File tree

18 files changed

+339
-40
lines changed

18 files changed

+339
-40
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,8 @@ void genScalarAssignment(fir::FirOpBuilder &builder, mlir::Location loc,
820820
const fir::ExtendedValue &lhs,
821821
const fir::ExtendedValue &rhs,
822822
bool needFinalization = false,
823-
bool isTemporaryLHS = false);
823+
bool isTemporaryLHS = false,
824+
mlir::ArrayAttr accessGroups = {});
824825

825826
/// Assign \p rhs to \p lhs. Both \p rhs and \p lhs must be scalar derived
826827
/// types. The assignment follows Fortran intrinsic assignment semantic for

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
306306
}];
307307

308308
let arguments = (ins AnyReferenceLike:$memref,
309-
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal);
309+
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal,
310+
OptionalAttr<LLVM_AccessGroupArrayAttr>:$accessGroups);
310311

311312
let builders = [OpBuilder<(ins "mlir::Value":$refVal)>,
312313
OpBuilder<(ins "mlir::Type":$resTy, "mlir::Value":$refVal)>];
@@ -339,7 +340,8 @@ def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface,
339340
}];
340341

341342
let arguments = (ins AnyType:$value, AnyReferenceLike:$memref,
342-
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal);
343+
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal,
344+
OptionalAttr<LLVM_AccessGroupArrayAttr>:$accessGroups);
343345

344346
let builders = [OpBuilder<(ins "mlir::Value":$value, "mlir::Value":$memref)>];
345347

@@ -2575,16 +2577,14 @@ def fir_CallOp : fir_Op<"call",
25752577
```
25762578
}];
25772579

2578-
let arguments = (ins
2579-
OptionalAttr<SymbolRefAttr>:$callee,
2580-
Variadic<AnyType>:$args,
2581-
OptionalAttr<DictArrayAttr>:$arg_attrs,
2582-
OptionalAttr<DictArrayAttr>:$res_attrs,
2583-
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
2584-
OptionalAttr<fir_FortranInlineAttr>:$inline_attr,
2585-
DefaultValuedAttr<Arith_FastMathAttr,
2586-
"::mlir::arith::FastMathFlags::none">:$fastmath
2587-
);
2580+
let arguments = (ins OptionalAttr<SymbolRefAttr>:$callee,
2581+
Variadic<AnyType>:$args, OptionalAttr<DictArrayAttr>:$arg_attrs,
2582+
OptionalAttr<DictArrayAttr>:$res_attrs,
2583+
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
2584+
OptionalAttr<fir_FortranInlineAttr>:$inline_attr,
2585+
OptionalAttr<LLVM_AccessGroupArrayAttr>:$accessGroups,
2586+
DefaultValuedAttr<Arith_FastMathAttr,
2587+
"::mlir::arith::FastMathFlags::none">:$fastmath);
25882588
let results = (outs Variadic<AnyType>);
25892589

25902590
let hasCustomAssemblyFormat = 1;

flang/include/flang/Parser/dump-parse-tree.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ class ParseTreeDumper {
207207
NODE(CompilerDirective, AssumeAligned)
208208
NODE(CompilerDirective, IgnoreTKR)
209209
NODE(CompilerDirective, Inline)
210+
NODE(CompilerDirective, IVDep)
210211
NODE(CompilerDirective, ForceInline)
211212
NODE(CompilerDirective, LoopCount)
212213
NODE(CompilerDirective, NameValue)

flang/include/flang/Parser/parse-tree.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3364,6 +3364,7 @@ struct StmtFunctionStmt {
33643364
// !DIR$ FORCEINLINE
33653365
// !DIR$ INLINE
33663366
// !DIR$ NOINLINE
3367+
// !DIR$ IVDEP
33673368
// !DIR$ <anything else>
33683369
struct CompilerDirective {
33693370
UNION_CLASS_BOILERPLATE(CompilerDirective);
@@ -3399,12 +3400,13 @@ struct CompilerDirective {
33993400
EMPTY_CLASS(ForceInline);
34003401
EMPTY_CLASS(Inline);
34013402
EMPTY_CLASS(NoInline);
3403+
EMPTY_CLASS(IVDep);
34023404
EMPTY_CLASS(Unrecognized);
34033405
CharBlock source;
34043406
std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>,
34053407
VectorAlways, std::list<NameValue>, Unroll, UnrollAndJam, Unrecognized,
34063408
NoVector, NoUnroll, NoUnrollAndJam, ForceInline, Inline, NoInline,
3407-
Prefetch>
3409+
Prefetch, IVDep>
34083410
u;
34093411
};
34103412

flang/lib/Lower/Bridge.cpp

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2304,6 +2304,35 @@ class FirConverter : public Fortran::lower::AbstractConverter {
23042304
}
23052305
}
23062306

2307+
// Add AccessGroups attribute on operations in fir::DoLoopOp if this
2308+
// operation has the parallelAccesses attribute.
2309+
void attachAccessGroupAttrToDoLoopOperations(fir::DoLoopOp &doLoop) {
2310+
if (auto loopAnnotAttr = doLoop.getLoopAnnotationAttr()) {
2311+
if (loopAnnotAttr.getParallelAccesses().size()) {
2312+
llvm::SmallVector<mlir::Attribute> accessGroupAttrs(
2313+
loopAnnotAttr.getParallelAccesses().begin(),
2314+
loopAnnotAttr.getParallelAccesses().end());
2315+
mlir::ArrayAttr attrs =
2316+
mlir::ArrayAttr::get(builder->getContext(), accessGroupAttrs);
2317+
doLoop.walk([&](mlir::Operation *op) {
2318+
if (fir::StoreOp storeOp = mlir::dyn_cast<fir::StoreOp>(op)) {
2319+
storeOp.setAccessGroupsAttr(attrs);
2320+
} else if (fir::LoadOp loadOp = mlir::dyn_cast<fir::LoadOp>(op)) {
2321+
loadOp.setAccessGroupsAttr(attrs);
2322+
} else if (hlfir::AssignOp assignOp =
2323+
mlir::dyn_cast<hlfir::AssignOp>(op)) {
2324+
// In some loops, the HLFIR AssignOp operation can be translated
2325+
// into FIR operation(s) containing StoreOp. It is therefore
2326+
// necessary to forward the AccessGroups attribute.
2327+
assignOp.getOperation()->setAttr("access_groups", attrs);
2328+
} else if (fir::CallOp callOp = mlir::dyn_cast<fir::CallOp>(op)) {
2329+
callOp.setAccessGroupsAttr(attrs);
2330+
}
2331+
});
2332+
}
2333+
}
2334+
}
2335+
23072336
/// Generate FIR for a DO construct. There are six variants:
23082337
/// - unstructured infinite and while loops
23092338
/// - structured and unstructured increment loops
@@ -2452,6 +2481,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
24522481
// This call may generate a branch in some contexts.
24532482
genFIR(endDoEval, unstructuredContext);
24542483

2484+
// Add AccessGroups attribute on operations in fir::DoLoopOp if necessary
2485+
for (IncrementLoopInfo &info : incrementLoopNestInfo)
2486+
if (auto loopOp = mlir::dyn_cast_if_present<fir::DoLoopOp>(info.loopOp))
2487+
attachAccessGroupAttrToDoLoopOperations(loopOp);
2488+
24552489
if (!incrementLoopNestInfo.empty() &&
24562490
incrementLoopNestInfo.back().isConcurrent)
24572491
localSymbols.popScope();
@@ -2540,22 +2574,31 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25402574
{}, {}, {}, {});
25412575
}
25422576

2577+
// Enabling loop vectorization attribute.
2578+
mlir::LLVM::LoopVectorizeAttr
2579+
genLoopVectorizeAttr(mlir::BoolAttr disableAttr) {
2580+
mlir::LLVM::LoopVectorizeAttr va;
2581+
if (disableAttr)
2582+
va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(),
2583+
/*disable=*/disableAttr, {}, {},
2584+
{}, {}, {}, {});
2585+
return va;
2586+
}
2587+
25432588
void addLoopAnnotationAttr(
25442589
IncrementLoopInfo &info,
25452590
llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
2546-
mlir::LLVM::LoopVectorizeAttr va;
2591+
mlir::BoolAttr disableVecAttr;
25472592
mlir::LLVM::LoopUnrollAttr ua;
25482593
mlir::LLVM::LoopUnrollAndJamAttr uja;
2594+
llvm::SmallVector<mlir::LLVM::AccessGroupAttr> aga;
25492595
bool has_attrs = false;
25502596
for (const auto *dir : dirs) {
25512597
Fortran::common::visit(
25522598
Fortran::common::visitors{
25532599
[&](const Fortran::parser::CompilerDirective::VectorAlways &) {
2554-
mlir::BoolAttr falseAttr =
2600+
disableVecAttr =
25552601
mlir::BoolAttr::get(builder->getContext(), false);
2556-
va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(),
2557-
/*disable=*/falseAttr,
2558-
{}, {}, {}, {}, {}, {});
25592602
has_attrs = true;
25602603
},
25612604
[&](const Fortran::parser::CompilerDirective::Unroll &u) {
@@ -2567,11 +2610,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25672610
has_attrs = true;
25682611
},
25692612
[&](const Fortran::parser::CompilerDirective::NoVector &u) {
2570-
mlir::BoolAttr trueAttr =
2613+
disableVecAttr =
25712614
mlir::BoolAttr::get(builder->getContext(), true);
2572-
va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(),
2573-
/*disable=*/trueAttr,
2574-
{}, {}, {}, {}, {}, {});
25752615
has_attrs = true;
25762616
},
25772617
[&](const Fortran::parser::CompilerDirective::NoUnroll &u) {
@@ -2582,13 +2622,21 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25822622
uja = genLoopUnrollAndJamAttr(/*unrollingFactor=*/0);
25832623
has_attrs = true;
25842624
},
2585-
2625+
[&](const Fortran::parser::CompilerDirective::IVDep &iv) {
2626+
disableVecAttr =
2627+
mlir::BoolAttr::get(builder->getContext(), false);
2628+
aga.push_back(
2629+
mlir::LLVM::AccessGroupAttr::get(builder->getContext()));
2630+
has_attrs = true;
2631+
},
25862632
[&](const auto &) {}},
25872633
dir->u);
25882634
}
2635+
mlir::LLVM::LoopVectorizeAttr va = genLoopVectorizeAttr(disableVecAttr);
25892636
mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get(
25902637
builder->getContext(), {}, /*vectorize=*/va, {}, /*unroll*/ ua,
2591-
/*unroll_and_jam*/ uja, {}, {}, {}, {}, {}, {}, {}, {}, {}, {});
2638+
/*unroll_and_jam*/ uja, {}, {}, {}, {}, {}, {}, {}, {}, {},
2639+
/*parallelAccesses*/ aga);
25922640
if (has_attrs) {
25932641
if (auto loopOp = mlir::dyn_cast<fir::DoLoopOp>(info.loopOp))
25942642
loopOp.setLoopAnnotationAttr(la);
@@ -3318,6 +3366,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
33183366
[&](const Fortran::parser::CompilerDirective::Prefetch &prefetch) {
33193367
TODO(getCurrentLocation(), "!$dir prefetch");
33203368
},
3369+
[&](const Fortran::parser::CompilerDirective::IVDep &) {
3370+
attachDirectiveToLoop(dir, &eval);
3371+
},
33213372
[&](const auto &) {}},
33223373
dir.u);
33233374
}

flang/lib/Lower/ConvertCall.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,8 @@ Fortran::lower::genCallOpAndResult(
713713
builder.getContext(), fir::FortranInlineEnum::always_inline);
714714
auto call = fir::CallOp::create(
715715
builder, loc, funcType.getResults(), funcSymbolAttr, operands,
716-
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs, inlineAttr);
716+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs, inlineAttr,
717+
/*accessGroups=*/mlir::ArrayAttr{});
717718

718719
callNumResults = call.getNumResults();
719720
if (callNumResults != 0)

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,12 +1393,10 @@ fir::ExtendedValue fir::factory::arraySectionElementToExtendedValue(
13931393
return fir::factory::componentToExtendedValue(builder, loc, element);
13941394
}
13951395

1396-
void fir::factory::genScalarAssignment(fir::FirOpBuilder &builder,
1397-
mlir::Location loc,
1398-
const fir::ExtendedValue &lhs,
1399-
const fir::ExtendedValue &rhs,
1400-
bool needFinalization,
1401-
bool isTemporaryLHS) {
1396+
void fir::factory::genScalarAssignment(
1397+
fir::FirOpBuilder &builder, mlir::Location loc,
1398+
const fir::ExtendedValue &lhs, const fir::ExtendedValue &rhs,
1399+
bool needFinalization, bool isTemporaryLHS, mlir::ArrayAttr accessGroups) {
14021400
assert(lhs.rank() == 0 && rhs.rank() == 0 && "must be scalars");
14031401
auto type = fir::unwrapSequenceType(
14041402
fir::unwrapPassByRefType(fir::getBase(lhs).getType()));
@@ -1420,7 +1418,9 @@ void fir::factory::genScalarAssignment(fir::FirOpBuilder &builder,
14201418
mlir::Value lhsAddr = fir::getBase(lhs);
14211419
rhsVal = builder.createConvert(loc, fir::unwrapRefType(lhsAddr.getType()),
14221420
rhsVal);
1423-
fir::StoreOp::create(builder, loc, rhsVal, lhsAddr);
1421+
fir::StoreOp store = fir::StoreOp::create(builder, loc, rhsVal, lhsAddr);
1422+
if (accessGroups)
1423+
store.setAccessGroupsAttr(accessGroups);
14241424
}
14251425
}
14261426

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,10 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
692692
}
693693
}
694694

695+
if (std::optional<mlir::ArrayAttr> optionalAccessGroups =
696+
call.getAccessGroups())
697+
llvmCall.setAccessGroups(*optionalAccessGroups);
698+
695699
if (memAttr)
696700
llvmCall.setMemoryEffectsAttr(
697701
mlir::cast<mlir::LLVM::MemoryEffectsAttr>(memAttr));
@@ -3402,6 +3406,9 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
34023406
loadOp.setTBAATags(*optionalTag);
34033407
else
34043408
attachTBAATag(loadOp, load.getType(), load.getType(), nullptr);
3409+
if (std::optional<mlir::ArrayAttr> optionalAccessGroups =
3410+
load.getAccessGroups())
3411+
loadOp.setAccessGroups(*optionalAccessGroups);
34053412
rewriter.replaceOp(load, loadOp.getResult());
34063413
}
34073414
return mlir::success();
@@ -3733,6 +3740,10 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
37333740
if (store.getNontemporal())
37343741
storeOp.setNontemporal(true);
37353742

3743+
if (std::optional<mlir::ArrayAttr> optionalAccessGroups =
3744+
store.getAccessGroups())
3745+
storeOp.setAccessGroups(*optionalAccessGroups);
3746+
37363747
newOp = storeOp;
37373748
}
37383749
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4285,7 +4285,7 @@ llvm::LogicalResult fir::StoreOp::verify() {
42854285

42864286
void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
42874287
mlir::Value value, mlir::Value memref) {
4288-
build(builder, result, value, memref, {});
4288+
build(builder, result, value, memref, {}, {}, {});
42894289
}
42904290

42914291
void fir::StoreOp::getEffects(

flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,18 @@ class AssignOpConversion : public mlir::OpRewritePattern<hlfir::AssignOp> {
149149
!assignOp.isTemporaryLHS() &&
150150
mlir::isa<fir::RecordType>(fir::getElementTypeOf(lhsExv));
151151

152+
mlir::ArrayAttr accessGroups;
153+
if (auto attrs = assignOp.getOperation()->getAttrOfType<mlir::ArrayAttr>(
154+
"access_groups"))
155+
accessGroups = attrs;
156+
152157
// genScalarAssignment() must take care of potential overlap
153158
// between LHS and RHS. Note that the overlap is possible
154159
// also for components of LHS/RHS, and the Assign() runtime
155160
// must take care of it.
156-
fir::factory::genScalarAssignment(builder, loc, lhsExv, rhsExv,
157-
needFinalization,
158-
assignOp.isTemporaryLHS());
161+
fir::factory::genScalarAssignment(
162+
builder, loc, lhsExv, rhsExv, needFinalization,
163+
assignOp.isTemporaryLHS(), accessGroups);
159164
}
160165
rewriter.eraseOp(assignOp);
161166
return mlir::success();

0 commit comments

Comments
 (0)