13
13
limitations under the License. */
14
14
15
15
#include < algorithm>
16
- #include < ctime>
17
-
18
16
#include " paddle/fluid/framework/op_registry.h"
19
17
#include " paddle/fluid/framework/var_type.h"
20
- #include " paddle/fluid/framework/variable.h"
21
18
22
19
namespace paddle {
23
20
namespace operators {
21
+ using framework::GradVarName;
24
22
25
23
#define CLOG std::cout
26
24
@@ -35,7 +33,7 @@ struct Formater {
35
33
std::type_index dtype{typeid (const char )};
36
34
framework::LoD lod;
37
35
int summarize;
38
- void * data{nullptr };
36
+ void * data{nullptr };
39
37
40
38
void operator ()(size_t size) {
41
39
PrintMessage ();
@@ -101,7 +99,7 @@ struct Formater {
101
99
102
100
template <typename T>
103
101
void Display (size_t size) {
104
- auto * d = reinterpret_cast <T*>(data);
102
+ auto * d = reinterpret_cast <T *>(data);
105
103
CLOG << " \t data: " ;
106
104
if (summarize != -1 ) {
107
105
summarize = std::min (size, (size_t )summarize);
@@ -120,51 +118,36 @@ struct Formater {
120
118
// TODO(ChunweiYan) there should be some other printers for TensorArray
121
119
class TensorPrintOp : public framework ::OperatorBase {
122
120
public:
123
- TensorPrintOp (const std::string& type,
124
- const framework::VariableNameMap& inputs,
125
- const framework::VariableNameMap& outputs,
126
- const framework::AttributeMap& attrs)
121
+ TensorPrintOp (const std::string & type,
122
+ const framework::VariableNameMap & inputs,
123
+ const framework::VariableNameMap & outputs,
124
+ const framework::AttributeMap & attrs)
127
125
: OperatorBase(type, inputs, outputs, attrs) {}
128
126
129
- TensorPrintOp (const TensorPrintOp& o)
127
+ TensorPrintOp (const TensorPrintOp & o)
130
128
: framework::OperatorBase(
131
- static_cast <const framework::OperatorBase&>(o)) {
129
+ static_cast <const framework::OperatorBase &>(o)) {
132
130
PADDLE_THROW (" Not implemented." );
133
131
}
134
132
135
133
private:
136
- void RunImpl (const framework::Scope& scope,
137
- const platform::Place& place) const override {
138
- const framework::Variable* in_var_ptr = nullptr ;
139
- std::string phase (kForward );
134
+ void RunImpl (const framework::Scope &scope,
135
+ const platform::Place &place) const override {
136
+ const framework::Variable *in_var_ptr = nullptr ;
140
137
std::string printed_var_name = " " ;
141
138
142
- auto & inputs = Inputs ();
143
- if (inputs.find (" In" ) != inputs.end () && !Inputs (" In" ).empty ()) {
144
- in_var_ptr = scope.FindVar (Input (" In" ));
145
- printed_var_name = Inputs (" In" ).front ();
146
- } else if (inputs.find (" In@GRAD" ) != inputs.end () &&
147
- !Inputs (" In@GRAD" ).empty ()) {
148
- in_var_ptr = scope.FindVar (Input (" In@GRAD" ));
149
- printed_var_name = Inputs (" In@GRAD" ).front ();
150
- phase = std::string (kBackward );
151
- } else {
152
- PADDLE_THROW (" Unknown phase, should be forward or backward." );
153
- }
139
+ in_var_ptr = scope.FindVar (Input (" In" ));
140
+ printed_var_name = Inputs (" In" ).front ();
154
141
155
142
PADDLE_ENFORCE_NOT_NULL (in_var_ptr);
156
143
157
- auto & in_tensor = in_var_ptr->Get <framework::LoDTensor>();
158
- auto * out_var_ptr = scope.FindVar (Output (" Out" ));
159
- auto & out_tensor = *out_var_ptr->GetMutable <framework::LoDTensor>();
160
-
161
- // Just copy data from input tensor to output tensor
162
- // output tensor share same memory with input tensor
163
- out_tensor.ShareDataWith (in_tensor);
164
- out_tensor.set_lod (in_tensor.lod ());
144
+ auto &in_tensor = in_var_ptr->Get <framework::LoDTensor>();
165
145
166
146
std::string print_phase = Attr<std::string>(" print_phase" );
167
- if (print_phase != phase && print_phase != std::string (kBoth )) {
147
+ bool is_forward = Attr<bool >(" is_forward" );
148
+
149
+ if ((is_forward && print_phase == kBackward ) ||
150
+ (!is_forward && print_phase == kForward )) {
168
151
return ;
169
152
}
170
153
@@ -192,15 +175,15 @@ class TensorPrintOp : public framework::OperatorBase {
192
175
formater.dtype = printed_tensor.type ();
193
176
}
194
177
if (Attr<bool >(" print_tensor_shape" )) {
195
- auto & dims = printed_tensor.dims ();
178
+ auto & dims = printed_tensor.dims ();
196
179
formater.dims .resize (dims.size ());
197
180
for (int i = 0 ; i < dims.size (); ++i) formater.dims [i] = dims[i];
198
181
}
199
182
if (Attr<bool >(" print_tensor_lod" )) {
200
183
formater.lod = printed_tensor.lod ();
201
184
}
202
185
formater.summarize = Attr<int >(" summarize" );
203
- formater.data = reinterpret_cast <void *>(printed_tensor.data <void >());
186
+ formater.data = reinterpret_cast <void *>(printed_tensor.data <void >());
204
187
formater (printed_tensor.numel ());
205
188
}
206
189
@@ -219,14 +202,14 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
219
202
AddAttr<bool >(" print_tensor_type" , " Whether to print the tensor's dtype." );
220
203
AddAttr<bool >(" print_tensor_shape" , " Whether to print the tensor's shape." );
221
204
AddAttr<bool >(" print_tensor_lod" , " Whether to print the tensor's lod." );
222
- AddAttr<std::string>(
223
- " print_phase " ,
224
- " (string, default 'BOTH') Which phase to display including 'FORWARD' "
225
- " 'BACKWARD' and 'BOTH'." )
205
+ AddAttr<std::string>(" print_phase " ,
206
+ " (string, default 'FORWARD') Which phase to display "
207
+ " including 'FORWARD' "
208
+ " 'BACKWARD' and 'BOTH'." )
226
209
.SetDefault (std::string (kBoth ))
227
210
.InEnum ({std::string (kForward ), std::string (kBackward ),
228
211
std::string (kBoth )});
229
- AddOutput ( " Out " , " Output tensor with same data as input tensor. " );
212
+ AddAttr< bool >( " is_forward " , " Whether is forward or not " ). SetDefault ( true );
230
213
AddComment (R"DOC(
231
214
Creates a print op that will print when a tensor is accessed.
232
215
@@ -238,40 +221,21 @@ tensor `t`.)DOC");
238
221
239
222
class InferShapeForward : public framework ::InferShapeBase {
240
223
public:
241
- void operator ()(framework::InferShapeContext* context) const override {
224
+ void operator ()(framework::InferShapeContext * context) const override {
242
225
PADDLE_ENFORCE (context->HasInput (" In" ), " Input(In) should not be null." );
243
- context->ShareLoD (" In" , /* ->*/ " Out" );
244
- context->SetOutputDim (" Out" , context->GetInputDim (" In" ));
245
- }
246
- };
247
-
248
- class InferShapeBackward : public framework ::InferShapeBase {
249
- public:
250
- void operator ()(framework::InferShapeContext* context) const override {
251
- PADDLE_ENFORCE (context->HasInput (" In@GRAD" ),
252
- " Input(In@GRAD) should not be null." );
253
- context->ShareLoD (" In@GRAD" , /* ->*/ " Out" );
254
- context->SetOutputDim (" Out" , context->GetInputDim (" In@GRAD" ));
255
226
}
256
227
};
257
228
258
- class InferVarType : public framework ::VarTypeInference {
259
- public:
260
- void operator ()(const framework::OpDesc& op_desc,
261
- framework::BlockDesc* block) const override {}
262
- };
263
-
264
- class PrintOpProtoAndCheckGradOpMaker
265
- : public framework::SingleGradOpDescMaker {
229
+ class PrintOpGradientMaker : public framework ::SingleGradOpDescMaker {
266
230
public:
267
231
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
268
232
269
233
std::unique_ptr<framework::OpDesc> Apply () const override {
270
- auto * op_desc_ptr = new framework::OpDesc ();
271
- op_desc_ptr->SetType (" print_grad" );
272
- op_desc_ptr->SetInput (" In@GRAD" , OutputGrad (" Out" ));
273
- op_desc_ptr->SetOutput (" Out" , InputGrad (" In" ));
234
+ auto *op_desc_ptr = new framework::OpDesc ();
235
+ op_desc_ptr->SetType (" print" );
236
+ op_desc_ptr->SetInput (" In" , InputGrad (" In" ));
274
237
op_desc_ptr->SetAttrMap (Attrs ());
238
+ op_desc_ptr->SetAttr (" is_forward" , false );
275
239
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
276
240
}
277
241
};
@@ -282,6 +246,4 @@ class PrintOpProtoAndCheckGradOpMaker
282
246
namespace ops = paddle::operators;
283
247
284
248
REGISTER_OPERATOR (print, ops::TensorPrintOp, ops::PrintOpProtoAndCheckMaker,
285
- ops::PrintOpProtoAndCheckGradOpMaker, ops::InferShapeForward,
286
- ops::InferVarType);
287
- REGISTER_OPERATOR (print_grad, ops::TensorPrintOp, ops::InferShapeBackward);
249
+ ops::PrintOpGradientMaker, ops::InferShapeForward);
0 commit comments