@@ -1123,6 +1123,8 @@ addReductionDecl(mlir::Location currentLocation,
11231123 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
11241124 if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
11251125 mlir::Value symVal = converter.getSymbolAddress (*symbol);
1126+ if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
1127+ symVal = declOp.getBase ();
11261128 mlir::Type redType =
11271129 symVal.getType ().cast <fir::ReferenceType>().getEleTy ();
11281130 reductionVars.push_back (symVal);
@@ -1160,6 +1162,8 @@ addReductionDecl(mlir::Location currentLocation,
11601162 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
11611163 if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
11621164 mlir::Value symVal = converter.getSymbolAddress (*symbol);
1165+ if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
1166+ symVal = declOp.getBase ();
11631167 mlir::Type redType =
11641168 symVal.getType ().cast <fir::ReferenceType>().getEleTy ();
11651169 reductionVars.push_back (symVal);
@@ -3746,6 +3750,8 @@ void Fortran::lower::genOpenMPReduction(
37463750 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
37473751 if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
37483752 mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
3753+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
3754+ reductionVal = declOp.getBase ();
37493755 mlir::Type reductionType =
37503756 reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
37513757 if (!reductionType.isa <fir::LogicalType>()) {
@@ -3789,6 +3795,9 @@ void Fortran::lower::genOpenMPReduction(
37893795 ompObject)}) {
37903796 if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
37913797 mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
3798+ if (auto declOp =
3799+ reductionVal.getDefiningOp <hlfir::DeclareOp>())
3800+ reductionVal = declOp.getBase ();
37923801 for (const mlir::OpOperand &reductionValUse :
37933802 reductionVal.getUses ()) {
37943803 if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
@@ -3844,6 +3853,13 @@ mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
38443853 return reductionOp;
38453854 }
38463855 }
3856+ if (auto assign =
3857+ mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner ())) {
3858+ if (assign.getLhs () == *reductionVal) {
3859+ assign.erase ();
3860+ return reductionOp;
3861+ }
3862+ }
38473863 }
38483864 }
38493865 }
@@ -3899,6 +3915,11 @@ void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
38993915 if (storeOp.getMemref () == symVal)
39003916 storeOp.erase ();
39013917 }
3918+ if (auto assignOp =
3919+ mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
3920+ if (assignOp.getLhs () == symVal)
3921+ assignOp.erase ();
3922+ }
39023923 }
39033924 }
39043925 }
0 commit comments