Skip to content

Commit 60ff05e

Browse files
committed
Merge branch 'luotao1-fix_rnn2_test' into fix/jit/exp
test=develop
2 parents b139b68 + ef09862 commit 60ff05e

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ namespace paddle {
1818
namespace inference {
1919

2020
using namespace framework; // NOLINT
21+
static std::vector<float> result_data;
2122

2223
struct DataRecord {
2324
std::vector<std::vector<std::vector<float>>> link_step_data_all;
2425
std::vector<size_t> lod;
2526
std::vector<std::vector<float>> rnn_link_data;
26-
std::vector<float> result_data;
2727
size_t num_samples; // total number of samples
2828
size_t batch_iter{0};
2929
size_t batch_size{1};
@@ -57,6 +57,7 @@ struct DataRecord {
5757
std::ifstream file(path);
5858
std::string line;
5959
int num_lines = 0;
60+
result_data.clear();
6061
while (std::getline(file, line)) {
6162
num_lines++;
6263
std::vector<std::string> data;
@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) {
135136

136137
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
137138
// the first inference result
138-
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
139139
PADDLE_ENFORCE_GT(outputs.size(), 0);
140140
size_t size = GetSize(outputs[0]);
141141
PADDLE_ENFORCE_GT(size, 0);
142142
float *result = static_cast<float *>(outputs[0].data.data());
143143
for (size_t i = 0; i < size; i++) {
144-
EXPECT_NEAR(result[i], data.result_data[i], 1e-3);
144+
EXPECT_NEAR(result[i], result_data[i], 1e-3);
145145
}
146146
}
147147
}

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/distributed/grpc_client.h"
16-
1715
#include <sys/time.h>
18-
1916
#include <limits>
2017

2118
#include "glog/logging.h" // For VLOG
2219
#include "paddle/fluid/framework/threadpool.h"
20+
#include "paddle/fluid/operators/distributed/grpc_client.h"
2321
#include "paddle/fluid/operators/distributed/grpc_serde.h"
2422
#include "paddle/fluid/operators/distributed/request_handler.h"
2523
#include "paddle/fluid/platform/profiler.h"
@@ -336,16 +334,22 @@ void GRPCClient::Proceed() {
336334
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
337335
c->Process();
338336
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
337+
// FIXME(gongwb): parse error_details?
339338
LOG(ERROR) << c->GetVarHandlePtr()->String()
340-
<< " meets grpc error:" << c->status_.error_message();
339+
<< " meets grpc error, error_code:" << c->status_.error_code()
340+
<< " error_message:" << c->status_.error_message()
341+
<< " error_details:" << c->status_.error_details();
341342
{
342343
std::lock_guard<std::mutex> lk(sync_mutex_);
343344
ok_ = false;
344345
}
345346
c->Finish(false);
346347
} else {
347348
LOG(FATAL) << c->GetVarHandlePtr()->String()
348-
<< " meets grpc error:" << c->status_.error_message();
349+
<< " meets grpc error, error_code:" << c->status_.error_code()
350+
<< " error_message:" << c->status_.error_message()
351+
<< " error_details:" << c->status_.error_details();
352+
349353
c->Finish(false);
350354
}
351355

0 commit comments

Comments
 (0)