Skip to content

Commit 4c82e45

Browse files
authored
[Cherry-Pick][Dy2St]Support call backward() without params in dy2st (#49812) (#50144)
* [Dy2St]Support call backward() without params in dy2st (#49812) * Support call backward() without params in dy2st * format code * format code
1 parent 8c5e432 commit 4c82e45

File tree

5 files changed

+351
-202
lines changed

5 files changed

+351
-202
lines changed

paddle/fluid/operators/select_output_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ class SelectOutputInferShape : public framework::InferShapeBase {
9393
void operator()(framework::InferShapeContext *context) const override {
9494
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SelectOutput");
9595
OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectOutput");
96-
OP_INOUT_CHECK(context->HasOutputs("Out"), "Output", "Out", "SelectOutput");
96+
OP_INOUT_CHECK(
97+
context->HasOutputs("Out", true), "Output", "Out", "SelectOutput");
9798
}
9899
};
99100

paddle/fluid/pybind/eager_legacy_custom_python_api.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ static PyObject *eager_api_run_program(PyObject *self,
2626
PyObject *kwargs) {
2727
PyThreadState *tstate = nullptr;
2828
try {
29-
auto X = GetTensorListFromArgs("run_program", "X", args, 0, false);
29+
auto X = GetTensorListFromArgs("run_program", "X", args, 0, true);
3030
auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true);
31-
auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, false);
31+
auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true);
3232
auto OutScope =
3333
GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
3434
auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);

0 commit comments

Comments
 (0)