Skip to content

Commit 32a0eb5

Browse files
authored
Pick disable reshape inplace in dygraph (#21486)
* disable reshape inplace in dygraph model; test=develop (#21157) * fix ExecutionContext::HasInput and ExecutionContext::HasOutput depend on the scope structure, test=develop (#20721)
1 parent df2b400 commit 32a0eb5

File tree

2 files changed

+4
-26
lines changed

2 files changed

+4
-26
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -414,34 +414,12 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
414414
}
415415

416416
bool ExecutionContext::HasInput(const std::string& name) const {
417-
if (!op_.HasInputs(name)) {
418-
return false;
419-
}
420-
auto& ins = Inputs(name);
421-
size_t length = ins.size();
422-
if (length == 0) {
423-
return false;
424-
}
425-
PADDLE_ENFORCE_EQ(length, 1UL,
426-
"Input %s should not have more than one inputs", name);
427-
auto arg = ins[0];
428-
auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
417+
auto* var = InputVar(name);
429418
return var != nullptr;
430419
}
431420

432421
bool ExecutionContext::HasOutput(const std::string& name) const {
433-
if (!op_.HasOutputs(name)) {
434-
return false;
435-
}
436-
auto& outs = Outputs(name);
437-
size_t length = outs.size();
438-
if (length == 0) {
439-
return false;
440-
}
441-
PADDLE_ENFORCE_EQ(length, 1UL,
442-
"Output %s should not have more than one inputs", name);
443-
auto arg = outs[0];
444-
auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
422+
auto* var = OutputVar(name);
445423
return var != nullptr;
446424
}
447425

python/paddle/fluid/layers/nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9011,8 +9011,8 @@ def get_attr_shape(list_shape):
90119011
actual_shape.stop_gradient = True
90129012
inputs["Shape"] = actual_shape
90139013

9014-
out = x if inplace else helper.create_variable_for_type_inference(
9015-
dtype=x.dtype)
9014+
out = x if inplace and not in_dygraph_mode(
9015+
) else helper.create_variable_for_type_inference(dtype=x.dtype)
90169016
x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
90179017
helper.append_op(
90189018
type="reshape2",

0 commit comments

Comments
 (0)