@@ -1108,6 +1108,100 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
11081108 }
11091109};
11101110
1111+ class EvaluateIntoMemoryAssignBufferization
1112+ : public mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp> {
1113+
1114+ public:
1115+ using mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp>::OpRewritePattern;
1116+
1117+ llvm::LogicalResult
1118+ matchAndRewrite (hlfir::EvaluateInMemoryOp,
1119+ mlir::PatternRewriter &rewriter) const override ;
1120+ };
1121+
1122+ static llvm::LogicalResult
1123+ tryUsingAssignLhsDirectly (hlfir::EvaluateInMemoryOp evalInMem,
1124+ mlir::PatternRewriter &rewriter) {
1125+ mlir::Location loc = evalInMem.getLoc ();
1126+ hlfir::DestroyOp destroy;
1127+ hlfir::AssignOp assign;
1128+ for (auto user : llvm::enumerate (evalInMem->getUsers ())) {
1129+ if (user.index () > 2 )
1130+ return mlir::failure ();
1131+ mlir::TypeSwitch<mlir::Operation *, void >(user.value ())
1132+ .Case ([&](hlfir::AssignOp op) { assign = op; })
1133+ .Case ([&](hlfir::DestroyOp op) { destroy = op; });
1134+ }
1135+ if (!assign || !destroy || destroy.mustFinalizeExpr () ||
1136+ assign.isAllocatableAssignment ())
1137+ return mlir::failure ();
1138+
1139+ hlfir::Entity lhs{assign.getLhs ()};
1140+ // EvaluateInMemoryOp memory is contiguous, so in general, it can only be
1141+ // replace by the LHS if the LHS is contiguous.
1142+ if (!lhs.isSimplyContiguous ())
1143+ return mlir::failure ();
1144+ // Character assignment may involves truncation/padding, so the LHS
1145+ // cannot be used to evaluate RHS in place without proving the LHS and
1146+ // RHS lengths are the same.
1147+ if (lhs.isCharacter ())
1148+ return mlir::failure ();
1149+ fir::AliasAnalysis aliasAnalysis;
1150+ // The region must not read or write the LHS.
1151+ // Note that getModRef is used instead of mlir::MemoryEffects because
1152+ // EvaluateInMemoryOp is typically expected to hold fir.calls and that
1153+ // Fortran calls cannot be modeled in a useful way with mlir::MemoryEffects:
1154+ // it is hard/impossible to list all the read/written SSA values in a call,
1155+ // but it is often possible to tell that an SSA value cannot be accessed,
1156+ // hence getModRef is needed here and below. Also note that getModRef uses
1157+ // mlir::MemoryEffects for operations that do not have special handling in
1158+ // getModRef.
1159+ if (aliasAnalysis.getModRef (evalInMem.getBody (), lhs).isModOrRef ())
1160+ return mlir::failure ();
1161+ // Any variables affected between the hlfir.evalInMem and assignment must not
1162+ // be read or written inside the region since it will be moved at the
1163+ // assignment insertion point.
1164+ auto effects = getEffectsBetween (evalInMem->getNextNode (), assign);
1165+ if (!effects) {
1166+ LLVM_DEBUG (
1167+ llvm::dbgs ()
1168+ << " operation with unknown effects between eval_in_mem and assign\n " );
1169+ return mlir::failure ();
1170+ }
1171+ for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
1172+ mlir::Value affected = effect.getValue ();
1173+ if (!affected ||
1174+ aliasAnalysis.getModRef (evalInMem.getBody (), affected).isModOrRef ())
1175+ return mlir::failure ();
1176+ }
1177+
1178+ rewriter.setInsertionPoint (assign);
1179+ fir::FirOpBuilder builder (rewriter, evalInMem.getOperation ());
1180+ mlir::Value rawLhs = hlfir::genVariableRawAddress (loc, builder, lhs);
1181+ hlfir::computeEvaluateOpIn (loc, builder, evalInMem, rawLhs);
1182+ rewriter.eraseOp (assign);
1183+ rewriter.eraseOp (destroy);
1184+ rewriter.eraseOp (evalInMem);
1185+ return mlir::success ();
1186+ }
1187+
1188+ llvm::LogicalResult EvaluateIntoMemoryAssignBufferization::matchAndRewrite (
1189+ hlfir::EvaluateInMemoryOp evalInMem,
1190+ mlir::PatternRewriter &rewriter) const {
1191+ if (mlir::succeeded (tryUsingAssignLhsDirectly (evalInMem, rewriter)))
1192+ return mlir::success ();
1193+ // Rewrite to temp + as_expr here so that the assign + as_expr pattern can
1194+ // kick-in for simple types and at least implement the assignment inline
1195+ // instead of call Assign runtime.
1196+ fir::FirOpBuilder builder (rewriter, evalInMem.getOperation ());
1197+ mlir::Location loc = evalInMem.getLoc ();
1198+ auto [temp, isHeapAllocated] = hlfir::computeEvaluateOpInNewTemp (
1199+ loc, builder, evalInMem, evalInMem.getShape (), evalInMem.getTypeparams ());
1200+ rewriter.replaceOpWithNewOp <hlfir::AsExprOp>(
1201+ evalInMem, temp, /* mustFree=*/ builder.createBool (loc, isHeapAllocated));
1202+ return mlir::success ();
1203+ }
1204+
11111205class OptimizedBufferizationPass
11121206 : public hlfir::impl::OptimizedBufferizationBase<
11131207 OptimizedBufferizationPass> {
@@ -1130,6 +1224,7 @@ class OptimizedBufferizationPass
11301224 patterns.insert <ElementalAssignBufferization>(context);
11311225 patterns.insert <BroadcastAssignBufferization>(context);
11321226 patterns.insert <VariableAssignBufferization>(context);
1227+ patterns.insert <EvaluateIntoMemoryAssignBufferization>(context);
11331228 patterns.insert <ReductionConversion<hlfir::CountOp>>(context);
11341229 patterns.insert <ReductionConversion<hlfir::AnyOp>>(context);
11351230 patterns.insert <ReductionConversion<hlfir::AllOp>>(context);
0 commit comments