Skip to content

Commit 8814bec

Browse files
emailweixudzhwinter
authored andcommitted
Show argument dimensions with operator::DebugStringEx (#7268)
This can make it easier to locate error.
1 parent a9f7cd3 commit 8814bec

File tree

5 files changed

+28
-7
lines changed

5 files changed

+28
-7
lines changed

paddle/framework/executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
111111

112112
for (auto& op_desc : block.AllOps()) {
113113
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
114-
VLOG(3) << op->DebugString();
114+
VLOG(3) << op->DebugStringEx(local_scope);
115115
op->Run(*local_scope, place_);
116116
if (FLAGS_check_nan_inf) {
117117
for (auto& vname : op->OutputVars(true)) {

paddle/framework/operator.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ void UseALL() {
7373
UseCUDNN();
7474
}
7575

76+
static DDim GetDims(const Scope& scope, const std::string& name) {
77+
Variable* var = scope.FindVar(name);
78+
if (var->IsType<LoDTensor>()) {
79+
return var->Get<LoDTensor>().dims();
80+
} else if (var->IsType<SelectedRows>()) {
81+
return var->Get<SelectedRows>().GetCompleteDims();
82+
} else {
83+
return DDim({-1});
84+
}
85+
}
86+
7687
std::string OperatorBase::Input(const std::string& name) const {
7788
auto& ins = Inputs(name);
7889
PADDLE_ENFORCE_LE(ins.size(), 1UL,
@@ -105,14 +116,17 @@ const std::vector<std::string>& OperatorBase::Outputs(
105116
return it->second;
106117
}
107118

108-
std::string OperatorBase::DebugString() const {
119+
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
109120
std::stringstream ss;
110121
ss << "Op(" << type_ << "), inputs:{";
111122
for (auto it = inputs_.begin(); it != inputs_.end();) {
112123
auto& input = *it;
113124
ss << input.first << "[";
114125
for (size_t i = 0; i < input.second.size(); ++i) {
115126
ss << input.second[i];
127+
if (scope) {
128+
ss << "(" << GetDims(*scope, input.second[i]) << ")";
129+
}
116130
if (i != input.second.size() - 1) {
117131
ss << ", ";
118132
}
@@ -129,6 +143,9 @@ std::string OperatorBase::DebugString() const {
129143
ss << output.first << "[";
130144
for (size_t i = 0; i < output.second.size(); ++i) {
131145
ss << output.second[i];
146+
if (scope) {
147+
ss << "(" << GetDims(*scope, output.second[i]) << ")";
148+
}
132149
if (i != output.second.size() - 1) {
133150
ss << ", ";
134151
}

paddle/framework/operator.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@ class OperatorBase {
108108
return boost::get<T>(attrs_.at(name));
109109
}
110110

111-
virtual std::string DebugString() const;
111+
/// if scope is not null, also show dimensions of arguments
112+
virtual std::string DebugStringEx(const Scope* scope) const;
113+
114+
std::string DebugString() const { return DebugStringEx(nullptr); }
112115

113116
/// Net will call this function to Run an op.
114117
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;

paddle/operators/net_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ void NetOp::CompleteAddOp(bool calc) {
5656
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
5757
}
5858

59-
std::string NetOp::DebugString() const {
59+
std::string NetOp::DebugStringEx(const framework::Scope* scope) const {
6060
std::ostringstream os;
61-
os << OperatorBase::DebugString() << std::endl;
61+
os << OperatorBase::DebugStringEx(scope) << std::endl;
6262
for (auto& op : ops_) {
63-
std::istringstream is(op->DebugString());
63+
std::istringstream is(op->DebugStringEx(scope));
6464
for (std::string line; std::getline(is, line);) {
6565
os << " " << line << std::endl;
6666
}

paddle/operators/net_op.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ class NetOp : public framework::OperatorBase {
106106

107107
void CompleteAddOp(bool calculate = true);
108108

109-
std::string DebugString() const override;
109+
std::string DebugStringEx(
110+
const framework::Scope* scope = nullptr) const override;
110111

111112
bool IsNetOp() const override;
112113
std::vector<std::string> OutputVars(bool has_intermediate) const override;

0 commit comments

Comments
 (0)