@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/backward.h"
-#include "paddle/fluid/operators/net_op.h"
 
 #include <deque>
 #include <list>
@@ -22,7 +21,6 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/net_op.h"
 
 namespace paddle {
 namespace framework {
@@ -60,12 +58,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
   if (grad_ops.size() == 1) {
     return std::move(grad_ops[0]);
   } else {
-    auto net_op = new operators::NetOp();
-    for (auto& grad_op : grad_ops) {
-      net_op->AppendOp(std::move(grad_op));
-    }
-    net_op->CompleteAddOp();
-    return std::unique_ptr<OperatorBase>(net_op);
+    PADDLE_THROW("Unexpected Branch");
   }
 }
 
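Note on the hunk above: when a forward op produced more than one gradient op, the deleted branch folded them into a single NetOp that ran its sub-ops in insertion order; after this change, CreateGradOp assumes the grad maker returns exactly one op and throws otherwise. Below is a minimal standalone sketch of the removed composition pattern, where `Op` and `SimpleNet` are illustrative stand-ins rather than Paddle's real `OperatorBase`/`NetOp` API:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Minimal stand-in for an executable operator.
struct Op {
  virtual ~Op() = default;
  virtual void Run() const = 0;
};

// Toy op that just reports its name when run.
struct PrintOp : Op {
  explicit PrintOp(std::string name) : name_(std::move(name)) {}
  void Run() const override { std::cout << "run " << name_ << "\n"; }
  std::string name_;
};

// Container op that owns sub-ops and runs them in insertion order,
// which is the role NetOp played in the deleted branch.
struct SimpleNet : Op {
  void AppendOp(std::unique_ptr<Op> op) { ops_.push_back(std::move(op)); }
  void Run() const override {
    for (const auto& op : ops_) op->Run();
  }
  std::vector<std::unique_ptr<Op>> ops_;
};

int main() {
  std::vector<std::unique_ptr<Op>> grad_ops;
  grad_ops.push_back(std::make_unique<PrintOp>("mul_grad"));
  grad_ops.push_back(std::make_unique<PrintOp>("add_grad"));

  std::unique_ptr<Op> grad;
  if (grad_ops.size() == 1) {
    grad = std::move(grad_ops[0]);
  } else {
    // The deleted branch: fold several grad ops into one container op.
    auto net = std::make_unique<SimpleNet>();
    for (auto& op : grad_ops) net->AppendOp(std::move(op));
    grad = std::move(net);
  }
  grad->Run();  // prints "run mul_grad" then "run add_grad"
}
```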
@@ -91,10 +84,7 @@ static bool AllInSet(
 }
 
 static std::unique_ptr<OperatorBase> NOP() {
-  auto net_op = new operators::NetOp();
-  net_op->SetType("@NOP@");
-  net_op->CompleteAddOp();
-  return std::unique_ptr<OperatorBase>(net_op);
+  PADDLE_THROW("Unexpected Branch");
 }
 
 // Get backward operator from a forward operator, a recursive implementation.
@@ -136,110 +126,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
   }
 
   // Returned gradient network
-  auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
-
-  if (forwardOp.IsNetOp()) {
-    // Because forwardOp is a net op, it can static_cast.
-    auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);
-
-    // Map from output gradient variable name to operator's indices in
-    // backward net's ops_. That operator generates that variable.
-    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
-
-    size_t local_op_id = 0;
-    // reversely travel forwardNet and collect all duplicate outputs.
-    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
-         ++it, ++local_op_id) {
-      auto& fwd = *it;
-      auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id);
-      ForEachVarName(bwd->Outputs(),
-                     [&dup_output_ops, local_op_id](const std::string& out) {
-                       dup_output_ops[out].emplace_back(local_op_id);
-                       return false;
-                     });
-      net->AppendOp(std::move(bwd));
-    }
-    // Get unique ID for this method.
-    auto uid = uniq_id++;
-    // TODO(dzh): more comment
-    // multiple operators which have the same output (y for example) may
-    // overwrite the same y variable when backward, special operations are token
-    // to handle this case. For each duplicate output, rename it to an alias
-    // (original name with a offset), append an `add` op for its operator,
-    // and finally sum all the alias variable to the final output variable y.
-    using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
-    std::list<Pos> insert_position;
-    for (auto& dup_output_op : dup_output_ops) {
-      const std::string& name = dup_output_op.first;
-      // duplicate @Empty@ don't need to be added
-      if (name == kEmptyVarName) continue;
-
-      auto& dup_op = dup_output_op.second;
-      // no duplicate output
-      if (dup_op.size() == 1) continue;
-
-      // process the duplicate outputs
-      std::vector<std::string> dup_outputs;
-      for (size_t i = 0; i < dup_op.size(); ++i) {
-        // rename each duplicate output to an alias
-        auto op_offset = dup_op[i];
-        dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
-                              std::to_string(i));
-        net->ops_[op_offset]->Rename(name, dup_outputs.back());
-      }
-      // collect all the offset for each alias,
-      // insert a sum operator to add all aliases to output
-      insert_position.push_back(
-          {dup_op.back(),
-           OpRegistry::CreateOp("sum", {{"X", dup_outputs}}, {{"Out", {name}}},
-                                AttributeMap{})});
-    }
-
-    // make sure the inserted `sum` ops follow the BFS order.
-    insert_position.sort(
-        [](const Pos& l, const Pos& r) { return l.first > r.first; });
-
-    for (auto& pos : insert_position) {
-      net->InsertOp(pos.first + 1, std::move(pos.second));
-    }
-  } else {
-    std::unique_ptr<OperatorBase> grad_op(
-        CreateGradOp(forwardOp, no_grad_names, grad_to_var));
-
-    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
-                                          const std::string& grad_input) {
-      if (no_grad_names.count(grad_input)) {
-        // +1 for \0
-        std::string prefix = grad_input.substr(
-            0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
-        grad_op->Rename(grad_input, prefix + kZeroVarSuffix);
-
-        // If part of input gradient of that operator is not calculated, fill
-        // zero variables to that input gradient.
-        net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", {{"X", {prefix}}},
-                                           {{"Out", {grad_input}}},
-                                           AttributeMap{}));
-      }
-      return false;
-    });
-
-    ForEachVarName(grad_op->Outputs(),
-                   [&no_grad_names, &grad_op](const std::string& grad_output) {
-                     if (no_grad_names.count(grad_output)) {
-                       grad_op->Rename(grad_output, kEmptyVarName);
-                     }
-                     return false;
-                   });
-
-    if (net->ops_.empty()) {  // Current no aux op is added to network
-      return grad_op;
-    }
-    net->AppendOp(std::move(grad_op));
-  }
-  net->SetType("@GENERATED_BACKWARD@");
-  net->CompleteAddOp();
-  return std::unique_ptr<OperatorBase>(
-      static_cast<OperatorBase*>(net.release()));
+  PADDLE_THROW("Unexpected Branch");
 }
 
 // See header for comments
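Note on the large hunk above: the deleted NetOp branch resolved duplicated gradient outputs by renaming each writer of a variable to a unique alias, `name + "@RENAME@" + uid + "@" + i`, and then inserting a `sum` op that folds the aliases back into the original variable. Here is a minimal sketch of that alias-and-sum scheme, using a plain map of doubles in place of real ops and scopes:

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Two backward ops both want to write "y@GRAD" (e.g. y feeds two ops).
  std::vector<std::pair<std::string, double>> writes = {{"y@GRAD", 1.5},
                                                        {"y@GRAD", 2.0}};

  std::map<std::string, double> scope;
  std::vector<std::string> aliases;
  const int uid = 0;  // per-recursion unique id in the real code
  for (std::size_t i = 0; i < writes.size(); ++i) {
    // Rename each duplicate write to its own alias so nothing is overwritten.
    std::string alias = writes[i].first + "@RENAME@" + std::to_string(uid) +
                        "@" + std::to_string(i);
    scope[alias] = writes[i].second;
    aliases.push_back(alias);
  }

  // The inserted `sum` op: add every alias back into the original variable.
  double total = 0.0;
  for (const auto& a : aliases) total += scope[a];
  scope["y@GRAD"] = total;

  std::cout << "y@GRAD = " << scope["y@GRAD"] << "\n";  // y@GRAD = 3.5
}
```

Sorting the insert positions in descending order before calling `InsertOp(pos.first + 1, ...)`, as the deleted code did, also keeps earlier insertions from shifting the offsets of later ones.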
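A second detail from the deleted `else` branch worth a worked example: the `// +1 for \0` comment refers to `sizeof` on a char-array suffix counting the trailing NUL, so `grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1` is exactly the length of the base variable name. The sketch below assumes the conventional `"@GRAD"` and `"@ZERO"` values for `kGradVarSuffix` and `kZeroVarSuffix`:

```cpp
#include <iostream>
#include <string>

// Assumed values; the convention matches the suffixes visible in this file.
constexpr char kGradVarSuffix[] = "@GRAD";
constexpr char kZeroVarSuffix[] = "@ZERO";

int main() {
  std::string grad_input = "x@GRAD";
  // sizeof(kGradVarSuffix) == 6 because it counts the '\0'; subtracting it
  // and adding 1 back leaves exactly the base name length (here, 1).
  std::string prefix = grad_input.substr(
      0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
  std::cout << prefix << "\n";                   // x
  std::cout << prefix + kZeroVarSuffix << "\n";  // x@ZERO
}
```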