Commit 8965cee

Polish PrintOp (#12895)

* Polish PrintOp
* Polish PrintOp
* Polish PrintOp
* Refine test_print_op

1 parent 9be39bb · commit 8965cee

File tree

4 files changed: +37 −79 lines

paddle/fluid/framework/var_type.h
paddle/fluid/operators/print_op.cc
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/tests/unittests/test_print_op.py

paddle/fluid/framework/var_type.h

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ namespace paddle {
 namespace framework {

 template <typename T>
-bool IsType(const std::type_index& type_index) {
+inline bool IsType(const std::type_index& type_index) {
   return type_index == std::type_index(typeid(T));
 }

(Note: IsType is defined in a header, so marking it inline lets print_op.cc and
every other translation unit include var_type.h without tripping
multiple-definition errors at link time.)
paddle/fluid/operators/print_op.cc

Lines changed: 33 additions & 71 deletions
@@ -13,14 +13,12 @@
 limitations under the License. */

 #include <algorithm>
-#include <ctime>
-
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/framework/variable.h"

 namespace paddle {
 namespace operators {
+using framework::GradVarName;

 #define CLOG std::cout

@@ -35,7 +33,7 @@ struct Formater {
   std::type_index dtype{typeid(const char)};
   framework::LoD lod;
   int summarize;
-  void* data{nullptr};
+  void *data{nullptr};

   void operator()(size_t size) {
     PrintMessage();
@@ -101,7 +99,7 @@ struct Formater {

   template <typename T>
   void Display(size_t size) {
-    auto* d = reinterpret_cast<T*>(data);
+    auto *d = reinterpret_cast<T *>(data);
     CLOG << "\tdata: ";
     if (summarize != -1) {
       summarize = std::min(size, (size_t)summarize);
@@ -120,51 +118,36 @@ struct Formater {
 // TODO(ChunweiYan) there should be some other printers for TensorArray
 class TensorPrintOp : public framework::OperatorBase {
  public:
-  TensorPrintOp(const std::string& type,
-                const framework::VariableNameMap& inputs,
-                const framework::VariableNameMap& outputs,
-                const framework::AttributeMap& attrs)
+  TensorPrintOp(const std::string &type,
+                const framework::VariableNameMap &inputs,
+                const framework::VariableNameMap &outputs,
+                const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}

-  TensorPrintOp(const TensorPrintOp& o)
+  TensorPrintOp(const TensorPrintOp &o)
       : framework::OperatorBase(
-            static_cast<const framework::OperatorBase&>(o)) {
+            static_cast<const framework::OperatorBase &>(o)) {
     PADDLE_THROW("Not implemented.");
   }

  private:
-  void RunImpl(const framework::Scope& scope,
-               const platform::Place& place) const override {
-    const framework::Variable* in_var_ptr = nullptr;
-    std::string phase(kForward);
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
+    const framework::Variable *in_var_ptr = nullptr;
     std::string printed_var_name = "";

-    auto& inputs = Inputs();
-    if (inputs.find("In") != inputs.end() && !Inputs("In").empty()) {
-      in_var_ptr = scope.FindVar(Input("In"));
-      printed_var_name = Inputs("In").front();
-    } else if (inputs.find("In@GRAD") != inputs.end() &&
-               !Inputs("In@GRAD").empty()) {
-      in_var_ptr = scope.FindVar(Input("In@GRAD"));
-      printed_var_name = Inputs("In@GRAD").front();
-      phase = std::string(kBackward);
-    } else {
-      PADDLE_THROW("Unknown phase, should be forward or backward.");
-    }
+    in_var_ptr = scope.FindVar(Input("In"));
+    printed_var_name = Inputs("In").front();

     PADDLE_ENFORCE_NOT_NULL(in_var_ptr);

-    auto& in_tensor = in_var_ptr->Get<framework::LoDTensor>();
-    auto* out_var_ptr = scope.FindVar(Output("Out"));
-    auto& out_tensor = *out_var_ptr->GetMutable<framework::LoDTensor>();
-
-    // Just copy data from input tensor to output tensor
-    // output tensor share same memory with input tensor
-    out_tensor.ShareDataWith(in_tensor);
-    out_tensor.set_lod(in_tensor.lod());
+    auto &in_tensor = in_var_ptr->Get<framework::LoDTensor>();

     std::string print_phase = Attr<std::string>("print_phase");
-    if (print_phase != phase && print_phase != std::string(kBoth)) {
+    bool is_forward = Attr<bool>("is_forward");
+
+    if ((is_forward && print_phase == kBackward) ||
+        (!is_forward && print_phase == kForward)) {
       return;
     }

@@ -192,15 +175,15 @@ class TensorPrintOp : public framework::OperatorBase {
       formater.dtype = printed_tensor.type();
     }
     if (Attr<bool>("print_tensor_shape")) {
-      auto& dims = printed_tensor.dims();
+      auto &dims = printed_tensor.dims();
       formater.dims.resize(dims.size());
       for (int i = 0; i < dims.size(); ++i) formater.dims[i] = dims[i];
     }
     if (Attr<bool>("print_tensor_lod")) {
       formater.lod = printed_tensor.lod();
     }
     formater.summarize = Attr<int>("summarize");
-    formater.data = reinterpret_cast<void*>(printed_tensor.data<void>());
+    formater.data = reinterpret_cast<void *>(printed_tensor.data<void>());
     formater(printed_tensor.numel());
   }

@@ -219,14 +202,14 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<bool>("print_tensor_type", "Whether to print the tensor's dtype.");
     AddAttr<bool>("print_tensor_shape", "Whether to print the tensor's shape.");
     AddAttr<bool>("print_tensor_lod", "Whether to print the tensor's lod.");
-    AddAttr<std::string>(
-        "print_phase",
-        "(string, default 'BOTH') Which phase to display including 'FORWARD' "
-        "'BACKWARD' and 'BOTH'.")
+    AddAttr<std::string>("print_phase",
+                         "(string, default 'FORWARD') Which phase to display "
+                         "including 'FORWARD' "
+                         "'BACKWARD' and 'BOTH'.")
         .SetDefault(std::string(kBoth))
         .InEnum({std::string(kForward), std::string(kBackward),
                  std::string(kBoth)});
-    AddOutput("Out", "Output tensor with same data as input tensor.");
+    AddAttr<bool>("is_forward", "Whether is forward or not").SetDefault(true);
     AddComment(R"DOC(
 Creates a print op that will print when a tensor is accessed.

@@ -238,40 +221,21 @@ tensor `t`.)DOC");

 class InferShapeForward : public framework::InferShapeBase {
  public:
-  void operator()(framework::InferShapeContext* context) const override {
+  void operator()(framework::InferShapeContext *context) const override {
     PADDLE_ENFORCE(context->HasInput("In"), "Input(In) should not be null.");
-    context->ShareLoD("In", /*->*/ "Out");
-    context->SetOutputDim("Out", context->GetInputDim("In"));
-  }
-};
-
-class InferShapeBackward : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext* context) const override {
-    PADDLE_ENFORCE(context->HasInput("In@GRAD"),
-                   "Input(In@GRAD) should not be null.");
-    context->ShareLoD("In@GRAD", /*->*/ "Out");
-    context->SetOutputDim("Out", context->GetInputDim("In@GRAD"));
   }
 };

-class InferVarType : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {}
-};
-
-class PrintOpProtoAndCheckGradOpMaker
-    : public framework::SingleGradOpDescMaker {
+class PrintOpGradientMaker : public framework::SingleGradOpDescMaker {
  public:
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

   std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* op_desc_ptr = new framework::OpDesc();
-    op_desc_ptr->SetType("print_grad");
-    op_desc_ptr->SetInput("In@GRAD", OutputGrad("Out"));
-    op_desc_ptr->SetOutput("Out", InputGrad("In"));
+    auto *op_desc_ptr = new framework::OpDesc();
+    op_desc_ptr->SetType("print");
+    op_desc_ptr->SetInput("In", InputGrad("In"));
     op_desc_ptr->SetAttrMap(Attrs());
+    op_desc_ptr->SetAttr("is_forward", false);
     return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
   }
 };
@@ -282,6 +246,4 @@ class PrintOpProtoAndCheckGradOpMaker
 namespace ops = paddle::operators;

 REGISTER_OPERATOR(print, ops::TensorPrintOp, ops::PrintOpProtoAndCheckMaker,
-                  ops::PrintOpProtoAndCheckGradOpMaker, ops::InferShapeForward,
-                  ops::InferVarType);
-REGISTER_OPERATOR(print_grad, ops::TensorPrintOp, ops::InferShapeBackward);
+                  ops::PrintOpGradientMaker, ops::InferShapeForward);
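
Taken together, these C++ changes collapse the old print / print_grad pair into
a single operator: PrintOpGradientMaker re-emits "print" itself with In bound to
the input's gradient and is_forward=false, and RunImpl compares that flag
against print_phase to decide whether to print. (The new attribute comment says
the default is 'FORWARD', but SetDefault still keeps kBoth.) A minimal sketch of
the resulting behavior from Python, using the fluid API of this era; the
executor and feeding boilerplate are illustrative, not part of this diff:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward

x = layers.data('x', shape=[3], dtype='float32', append_batch_size=False)
x.stop_gradient = False
# One Print call now yields up to two runtime prints: the forward 'print'
# op (is_forward=True) and the clone emitted by the gradient maker
# (is_forward=False), each gated by print_phase.
layers.Print(x, message='x / dx: ', print_phase='both')
loss = layers.mean(x)
append_backward(loss=loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
exe.run(feed={'x': np.random.rand(3).astype('float32')}, fetch_list=[loss])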

python/paddle/fluid/layers/control_flow.py

Lines changed: 1 addition & 4 deletions
@@ -189,7 +189,6 @@ def Print(input,
            message="The content of some_layer: ")
     '''
     helper = LayerHelper('print', **locals())
-    out = helper.create_tmp_variable(dtype=helper.input_dtype())
     helper.append_op(
         type='print',
         inputs={'In': input},
@@ -202,9 +201,7 @@ def Print(input,
             'print_tensor_shape': print_tensor_shape,
             'print_tensor_lod': print_tensor_lod,
             'print_phase': print_phase.upper()
-        },
-        outputs={'Out': out})
-    return out
+        })


 class BlockGuard(object):
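
On the Python side, the visible API change is that Print no longer allocates an
Out variable or returns one: it prints purely as a side effect, and call sites
keep building on the original tensor. A small migration sketch (variable names
are illustrative):

import paddle.fluid.layers as layers

x = layers.data('x', shape=[3], dtype='float32')
# Before this commit, the printed tensor had to be threaded through:
#     printed = layers.Print(input=x, message="The content of some_layer: ")
#     loss = layers.mean(printed)
# After it, Print returns None and the graph is built on x directly:
layers.Print(input=x, message="The content of some_layer: ")
loss = layers.mean(x)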

python/paddle/fluid/tests/unittests/test_print_op.py

Lines changed: 2 additions & 3 deletions
@@ -35,9 +35,8 @@ def setUp(self):
     def build_network(self, only_forward, **kargs):
         x = layers.data('x', shape=[3], dtype='float32', lod_level=1)
         x.stop_gradient = False
-        printed = layers.Print(input=x, **kargs)
-        if only_forward: return printed
-        loss = layers.mean(printed)
+        layers.Print(input=x, **kargs)
+        loss = layers.mean(x)
         append_backward(loss=loss)
         return loss
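
To execute the network the test builds, the feed for 'x' must be a one-level
LoDTensor. A hedged sketch of such a run, continuing from build_network above;
set_lod takes offset-style LoD here, and the concrete values are illustrative
assumptions rather than part of this diff:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core

place = core.CPUPlace()
x_tensor = core.LoDTensor()
x_tensor.set(np.random.random(size=(2, 3)).astype('float32'), place)
x_tensor.set_lod([[0, 1, 2]])  # two sequences of length 1, offset form

exe = fluid.Executor(place)
# 'loss' is the variable returned by build_network(...) above.
exe.run(fluid.default_main_program(), feed={'x': x_tensor},
        fetch_list=[loss])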
