Skip to content

Commit c737116

Browse files
authored
[Cherry Pick] Add error info during compile (#19300)
* Add call stack info during runtime and compile time test=develop
1 parent 737f21b commit c737116

File tree

6 files changed

+163
-41
lines changed

6 files changed

+163
-41
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
125125
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
126126
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
127127
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
128-
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
128+
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
129129

130130
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
131131

@@ -136,6 +136,8 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
136136

137137
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
138138

139+
cc_library(op_call_stack SRCS op_call_stack.cc DEPS op_proto_maker enforce)
140+
139141
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
140142

141143
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/op_call_stack.h"
16+
#include <string>
17+
#include <vector>
18+
#include "paddle/fluid/framework/attribute.h"
19+
#include "paddle/fluid/framework/op_proto_maker.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
24+
void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs,
25+
platform::EnforceNotMet *exception) {
26+
if (attrs.count("sub_block") != 0) {
27+
return;
28+
}
29+
auto &callstack = boost::get<std::vector<std::string>>(
30+
attrs.at(OpProtoAndCheckerMaker::OpCreationCallstackAttrName()));
31+
32+
if (callstack.empty()) {
33+
return;
34+
}
35+
std::ostringstream sout;
36+
sout << "Invoke operator " << type << " error.\n";
37+
sout << "Python Call stacks: \n";
38+
for (auto &line : callstack) {
39+
sout << line;
40+
}
41+
sout << "C++ Call stacks: \n";
42+
sout << exception->err_str_;
43+
exception->err_str_ = sout.str();
44+
}
45+
46+
} // namespace framework
47+
} // namespace paddle
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "paddle/fluid/framework/type_defs.h"
19+
#include "paddle/fluid/platform/enforce.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs,
24+
platform::EnforceNotMet *exception);
25+
} // namespace framework
26+
} // namespace paddle

paddle/fluid/framework/op_desc.cc

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ limitations under the License. */
1818
#include <mutex> // NOLINT
1919
#include <string>
2020
#include <unordered_map>
21+
#include <utility>
2122
#include "glog/logging.h"
2223
#include "paddle/fluid/framework/block_desc.h"
24+
#include "paddle/fluid/framework/op_call_stack.h"
2325
#include "paddle/fluid/framework/op_proto_maker.h"
2426
#include "paddle/fluid/framework/operator.h"
2527
#include "paddle/fluid/framework/program_desc.h"
@@ -679,26 +681,33 @@ void OpDesc::CheckAttrs() {
679681
}
680682

681683
void OpDesc::InferShape(const BlockDesc &block) const {
682-
VLOG(3) << "CompileTime infer shape on " << Type();
683-
InitInferShapeFuncs();
684-
auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
685-
PADDLE_ENFORCE(static_cast<bool>(infer_shape),
686-
"%s's infer_shape has not been registered", this->Type());
687-
CompileTimeInferShapeContext ctx(*this, block);
688-
if (VLOG_IS_ON(10)) {
689-
std::ostringstream sout;
690-
auto inames = this->InputArgumentNames();
691-
sout << " From [";
692-
std::copy(inames.begin(), inames.end(),
693-
std::ostream_iterator<std::string>(sout, ", "));
694-
sout << "] to [";
695-
auto onames = this->OutputArgumentNames();
696-
std::copy(onames.begin(), onames.end(),
697-
std::ostream_iterator<std::string>(sout, ", "));
698-
sout << "]";
699-
VLOG(10) << sout.str();
700-
}
701-
infer_shape(&ctx);
684+
try {
685+
VLOG(3) << "CompileTime infer shape on " << Type();
686+
InitInferShapeFuncs();
687+
auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
688+
PADDLE_ENFORCE(static_cast<bool>(infer_shape),
689+
"%s's infer_shape has not been registered", this->Type());
690+
CompileTimeInferShapeContext ctx(*this, block);
691+
if (VLOG_IS_ON(10)) {
692+
std::ostringstream sout;
693+
auto inames = this->InputArgumentNames();
694+
sout << " From [";
695+
std::copy(inames.begin(), inames.end(),
696+
std::ostream_iterator<std::string>(sout, ", "));
697+
sout << "] to [";
698+
auto onames = this->OutputArgumentNames();
699+
std::copy(onames.begin(), onames.end(),
700+
std::ostream_iterator<std::string>(sout, ", "));
701+
sout << "]";
702+
VLOG(10) << sout.str();
703+
}
704+
infer_shape(&ctx);
705+
} catch (platform::EnforceNotMet exception) {
706+
framework::InsertCallStackInfo(Type(), attrs_, &exception);
707+
throw std::move(exception);
708+
} catch (...) {
709+
std::rethrow_exception(std::current_exception());
710+
}
702711
}
703712

704713
void OpDesc::InferVarType(BlockDesc *block) const {

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License. */
2323
#include "paddle/fluid/framework/data_transform.h"
2424
#include "paddle/fluid/framework/executor.h"
2525
#include "paddle/fluid/framework/lod_tensor.h"
26+
#include "paddle/fluid/framework/op_call_stack.h"
2627
#include "paddle/fluid/framework/op_proto_maker.h"
2728
#include "paddle/fluid/framework/operator.h"
2829
#include "paddle/fluid/framework/shape_inference.h"
@@ -183,28 +184,9 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
183184
} else {
184185
RunImpl(scope, place);
185186
}
186-
187187
VLOG(3) << place << " " << DebugStringEx(&scope);
188188
} catch (platform::EnforceNotMet exception) {
189-
if (Attrs().count("sub_block") != 0) {
190-
throw std::move(exception);
191-
}
192-
193-
auto& callstack = Attr<std::vector<std::string>>(
194-
OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
195-
196-
if (callstack.empty()) {
197-
throw std::move(exception);
198-
}
199-
std::ostringstream sout;
200-
sout << "Invoke operator " << Type() << " error.\n";
201-
sout << "Python Callstacks: \n";
202-
for (auto& line : callstack) {
203-
sout << line;
204-
}
205-
sout << "C++ Callstacks: \n";
206-
sout << exception.err_str_;
207-
exception.err_str_ = sout.str();
189+
framework::InsertCallStackInfo(Type(), Attrs(), &exception);
208190
throw std::move(exception);
209191
} catch (...) {
210192
std::rethrow_exception(std::current_exception());
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
from op_test import OpTest
20+
import paddle.fluid as fluid
21+
import paddle.fluid.core as core
22+
23+
24+
class TestRunTimeException(OpTest):
25+
def test_run_time_exception(self):
26+
place = fluid.CPUPlace()
27+
exe = fluid.Executor(place)
28+
29+
train_program = fluid.Program()
30+
startup_program = fluid.Program()
31+
with fluid.program_guard(train_program, startup_program):
32+
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
33+
fluid.layers.one_hot(input=label, depth=100)
34+
35+
def _run_program():
36+
x = np.random.random(size=(10)).astype('int64')
37+
exe.run(train_program, feed={"label": x})
38+
39+
self.assertRaises(core.EnforceNotMet, _run_program)
40+
41+
42+
class TestCompileTimeException(OpTest):
43+
def test_compile_time_exception(self):
44+
self.assertRaises(core.EnforceNotMet, self.build_model)
45+
46+
def build_model(self):
47+
train_program = fluid.Program()
48+
startup_program = fluid.Program()
49+
with fluid.program_guard(train_program, startup_program):
50+
label = fluid.layers.data(
51+
name="label", shape=[1], dtype="int64", append_batch_size=False)
52+
fluid.layers.one_hot(input=label, depth=100)
53+
54+
55+
if __name__ == '__main__':
56+
unittest.main()

0 commit comments

Comments
 (0)