
Commit c7a6a1f

fix runtime crash when rnn model inference, test=develop (#31833) (#31846)
1 parent: d44d173

3 files changed: 18 additions & 17 deletions

paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc

Lines changed: 1 addition & 0 deletions
```diff
@@ -105,6 +105,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
         "merge_lod_tensor",
         "equal",
         "sequence_pool",
+        "recurrent",
         "lod_reset"};
     for (auto* tmp : node->inputs) {
       CHECK(tmp->IsOp());
```
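
This list names the ops whose buffers the memory-optimize pass must leave alone; `recurrent` was missing, so the pass could hand an RNN step block's buffers to other ops during inference, matching the crash this commit fixes. Below is a minimal sketch of the configuration that exercises the patched pass, assuming the Paddle 2.x inference API; the model file names and input shape are placeholders.

```python
import numpy as np
import paddle.inference as paddle_infer

# Placeholder files; an RNN model saved for inference is assumed.
config = paddle_infer.Config("rnn_model.pdmodel", "rnn_model.pdiparams")
config.enable_memory_optim()  # runs MemoryOptimizePass, the code patched above

predictor = paddle_infer.create_predictor(config)
input_name = predictor.get_input_names()[0]
input_handle = predictor.get_input_handle(input_name)
input_handle.copy_from_cpu(np.random.rand(1, 16, 8).astype("float32"))
predictor.run()  # previously could crash when the model contains recurrent ops
```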

paddle/fluid/operators/recurrent_op.cc

Lines changed: 12 additions & 13 deletions
```diff
@@ -211,9 +211,10 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
   auto *block = Attr<framework::BlockDesc *>(kStepBlock);

   auto *program = block->Program();
-  auto ctx = executor.Prepare(
-      *program, block->ID(), Attr<std::vector<std::string>>(
-                                 kSkipEagerDeletionVars) /*skip_ref_cnt_vars*/);
+  auto ctx = executor.Prepare(*program, block->ID(),
+                              Attr<std::vector<std::string>>(
+                                  kSkipEagerDeletionVars), /*skip_ref_cnt_vars*/
+                              true);

   static std::mutex mutex;
   std::lock_guard<std::mutex> lock(mutex);
@@ -256,16 +257,6 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
     // Link inside::output -> outside::output
     // outside::output[seq_offset: seq_offset + 1] = inside::output
     executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
-    if (i > 0) {
-      LinkTensorWithCallback(scope, Outputs(kOutputs), cur_scope,
-                             Outputs(kOutputs),
-                             [&](const framework::LoDTensor &src_tensor,
-                                 framework::LoDTensor *dst_tensor) {
-                               framework::Tensor src_slice =
-                                   src_tensor.Slice(seq_offset, seq_offset + 1);
-                               dst_tensor->ShareDataWith(src_slice);
-                             });
-    }

     // Linked now, execute!
     executor.RunPreparedContext(ctx.get(), &cur_scope,
@@ -285,6 +276,14 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
            // early.
            framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
          });
+    } else {
+      LinkTensorWithCallback(
+          cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
+          [&](const framework::LoDTensor &src_tensor,
+              framework::LoDTensor *dst_tensor) {
+            auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
+            framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
+          });
     }

     scopes.ForwardNext();
```
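
Two changes land here. First, `Executor::Prepare` now passes `true` as its final argument (which, if the signature is unchanged in this era of the codebase, is `force_disable_gc`), so the step block's intermediates are not garbage-collected mid-run. Second, outputs are no longer aliased between the step scope and the outside tensor via `ShareDataWith` before each step; instead every step's result is explicitly copied out with `TensorCopy` after it runs, so reusing or destroying the step scope cannot corrupt the outside output. The sketch below is plain NumPy, purely to illustrate the aliasing-versus-copy distinction; it is not Paddle internals.

```python
import numpy as np

seq_len, batch, hidden = 4, 2, 3
outside = np.zeros((seq_len, batch, hidden), dtype=np.float32)  # outside::output
step_buf = np.empty((batch, hidden), dtype=np.float32)          # reused step-scope buffer

h = np.zeros((batch, hidden), dtype=np.float32)
for t in range(seq_len):
    h = np.tanh(h + t)      # stand-in for the step block's computation
    np.copyto(step_buf, h)  # the step's result lives in the step scope
    outside[t] = step_buf   # explicit copy-out, analogous to TensorCopy

step_buf.fill(0.0)  # reusing the step scope no longer corrupts `outside`
assert outside.any()
```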

python/paddle/nn/functional/norm.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -189,10 +189,10 @@ def batch_norm(x,

     if in_dygraph_mode():
         # for dygraph need tuple
-        attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout",
-                 data_format, "use_mkldnn", False, "fuse_with_relu", False,
-                 "use_global_stats", use_global_stats, "trainable_statistics",
-                 trainable_statistics)
+        attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
+                 not training, "data_layout", data_format, "use_mkldnn", False,
+                 "fuse_with_relu", False, "use_global_stats", use_global_stats,
+                 "trainable_statistics", trainable_statistics)
         batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
             x, weight, bias, running_mean, running_var, mean_out, variance_out,
             *attrs)
@@ -207,6 +207,7 @@ def batch_norm(x,
     attrs = {
         "momentum": momentum,
         "epsilon": epsilon,
+        "is_test": not training,
         "data_layout": data_format,
         "use_mkldnn": False,
         "fuse_with_relu": False,
```