
Commit 9e80551

support dumping params/grads in transpiler mode (#22490) (#22649)
1 parent 5515597 commit 9e80551

18 files changed: +434 -132 lines

paddle/fluid/framework/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -66,9 +66,11 @@ else()
   cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
 endif()
 cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
+cc_library(device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor)
 
 cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
 nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
+cc_test(device_worker_test SRCS device_worker_test.cc DEPS device_worker)
 
 cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory gflags glog)

paddle/fluid/framework/device_worker.cc

Lines changed: 68 additions & 0 deletions
@@ -23,5 +23,73 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) {
   device_reader_ = data_feed;
 }
 
+template <typename T>
+std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
+  auto count = tensor->numel();
+  if (start < 0 || end > count) {
+    VLOG(3) << "access violation";
+    return "access violation";
+  }
+  std::ostringstream os;
+  for (int64_t i = start; i < end; i++) {
+    os << ":" << tensor->data<T>()[i];
+  }
+  return os.str();
+}
+
+std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
+                                  int64_t end) {
+  auto count = tensor->numel();
+  if (start < 0 || end > count) {
+    VLOG(3) << "access violation";
+    return "access violation";
+  }
+  std::ostringstream os;
+  for (int64_t i = start; i < end; i++) {
+    os << ":" << static_cast<uint64_t>(tensor->data<int64_t>()[i]);
+  }
+  return os.str();
+}
+
+std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
+  std::string out_val;
+  if (tensor->type() == proto::VarType::FP32) {
+    out_val = PrintLodTensorType<float>(tensor, start, end);
+  } else if (tensor->type() == proto::VarType::INT64) {
+    out_val = PrintLodTensorIntType(tensor, start, end);
+  } else if (tensor->type() == proto::VarType::FP64) {
+    out_val = PrintLodTensorType<double>(tensor, start, end);
+  } else {
+    out_val = "unsupported type";
+  }
+  return out_val;
+}
+
+std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
+  auto& dims = tensor->dims();
+  if (tensor->lod().size() != 0) {
+    auto& lod = tensor->lod()[0];
+    return {lod[index] * dims[1], lod[index + 1] * dims[1]};
+  } else {
+    return {index * dims[1], (index + 1) * dims[1]};
+  }
+}
+
+bool CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
+  auto& dims = tensor->dims();
+  if (dims.size() != 2) return false;
+  if (tensor->lod().size() != 0) {
+    auto& lod = tensor->lod()[0];
+    if (lod.size() != batch_size + 1) {
+      return false;
+    }
+  } else {
+    if (dims[0] != static_cast<int>(batch_size)) {
+      return false;
+    }
+  }
+  return true;
+}
+
 }  // namespace framework
 }  // namespace paddle
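The helpers added here (PrintLodTensor, GetTensorBound, CheckValidOutput) are now free functions in the framework namespace, so any device worker can reuse them rather than going through DownpourWorker. A minimal sketch of how they might be combined to dump one output variable of a batch; the function name DumpOneField, the scope argument, and the record prefix are illustrative assumptions, not code from this commit.

// Illustrative sketch only, not part of this commit: dump each instance of one
// output variable as a "(batch_id,name):v0:v1:..." record.
#include <string>
#include <vector>

#include "paddle/fluid/framework/device_worker.h"

namespace paddle {
namespace framework {

std::vector<std::string> DumpOneField(Scope* scope, const std::string& name,
                                      size_t batch_size, int batch_id) {
  std::vector<std::string> records;
  Variable* var = scope->FindVar(name);
  if (var == nullptr) {
    return records;
  }
  LoDTensor* tensor = var->GetMutable<LoDTensor>();
  // Skip tensors whose shape or LoD does not match the current batch.
  if (!CheckValidOutput(tensor, batch_size)) {
    return records;
  }
  for (size_t i = 0; i < batch_size; ++i) {
    // Map instance i to its [start, end) slice of the flattened data.
    auto bound = GetTensorBound(tensor, static_cast<int>(i));
    records.push_back("(" + std::to_string(batch_id) + "," + name + ")" +
                      PrintLodTensor(tensor, bound.first, bound.second));
  }
  return records;
}

}  // namespace framework
}  // namespace paddle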

paddle/fluid/framework/device_worker.h

Lines changed: 16 additions & 9 deletions
@@ -45,6 +45,10 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end);
+std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
+bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
+
 class FleetWrapper;
 
 #define SEC_LOG \
@@ -168,6 +172,8 @@ class HogwildWorker : public CPUWorkerBase {
   virtual void Initialize(const TrainerDesc& desc);
   virtual void TrainFiles();
   virtual void TrainFilesWithProfiler();
+  virtual void SetNeedDump(bool need_dump_field);
+  virtual void SetChannelWriter(ChannelObject<std::string>* queue);
   virtual void PrintFetchVars();
   virtual void CreateDeviceResource(const ProgramDesc& main_prog);
   virtual void BindingDataFeedMemory();
@@ -177,13 +183,21 @@ class HogwildWorker : public CPUWorkerBase {
  protected:
   void CreateThreadOperators(const ProgramDesc& program);
   void CreateThreadScope(const ProgramDesc& program);
+  virtual void DumpParam(const int batch_id);
+
   std::vector<std::string> op_names_;
   std::vector<OperatorBase*> ops_;
   bool thread_barrier_;
   // Scope* thread_scope_;
   HogwildWorkerParameter param_;
   std::vector<std::string> skip_ops_;
   std::map<std::string, int> stat_var_name_map_;
+  // dump params or grads for debug
+  bool need_dump_param_;
+  bool need_dump_field_;
+  std::vector<std::string> dump_param_;
+  std::vector<std::string> dump_fields_;
+  ChannelWriter<std::string> writer_;
 };
 
 class DownpourWorker : public HogwildWorker {
@@ -203,13 +217,11 @@ class DownpourWorker : public HogwildWorker {
   void PushGradients();
   void CollectLabelInfo(size_t table_id);
   void AdjustInsWeight();
-  void DumpParam();
   void CopySparseTable();
   void CopyDenseTable();
   void CopyDenseVars();
-  std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end);
-  std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
-  bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
+  virtual void DumpParam(const int batch_id);
+
   DownpourWorkerParameter param_;
   // copy table
   CopyTableConfig copy_table_config_;
@@ -236,16 +248,11 @@ class DownpourWorker : public HogwildWorker {
   std::vector<::std::future<int32_t>> push_sparse_status_;
   bool dump_slot_;
   bool need_to_push_dense_;
-  bool need_dump_field_;
-  bool need_dump_param_;
   std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
   float scale_datanorm_;
   std::vector<::std::future<int32_t>> push_dense_status_;
-  std::vector<std::string> dump_fields_;
-  ChannelWriter<std::string> writer_;
   // skipped ops
   std::vector<std::string> skip_ops_;
-  std::vector<std::string> dump_param_;
   // just save the value in param_ for easy access
   std::map<uint64_t, std::string> label_var_name_;
   std::map<uint64_t, std::vector<std::string>> dense_value_names_;
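Because the dump members (need_dump_field_, dump_fields_, dump_param_, writer_, and friends) now live on HogwildWorker, the transpiler-mode worker gains the same dumping machinery that DownpourWorker already had. A hypothetical trainer-side wiring sketch, assuming the trainer owns the worker list and the dump channel (both the function and its arguments are invented for illustration):

// Hypothetical wiring, not part of this commit: the worker container and the
// dump channel are assumed to be created elsewhere by the trainer.
#include <string>
#include <vector>

#include "paddle/fluid/framework/device_worker.h"

void WireDumpChannel(
    const std::vector<paddle::framework::HogwildWorker*>& workers,
    paddle::framework::ChannelObject<std::string>* dump_queue,
    bool need_dump_field) {
  for (auto* worker : workers) {
    // Both HogwildWorker and DownpourWorker (which derives from it) expose these.
    worker->SetNeedDump(need_dump_field);  // toggle per-batch field dumping
    worker->SetChannelWriter(dump_queue);  // all workers write into one channel
  }
}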

paddle/fluid/framework/device_worker_test.cc

Lines changed: 55 additions & 2 deletions
@@ -12,13 +12,66 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/device_worker.h"
 #include <gtest/gtest.h>
+#include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/trainer.h"
 
 namespace paddle {
 namespace framework {
-TEST() {
-  // create hogwild device worker
+TEST(LodTensor, PrintLodTensor) {
+  LoDTensor tensor1;
+  tensor1.Resize({2});
+  tensor1.mutable_data<float>(platform::CPUPlace());
+  tensor1.data<float>()[0] = 0.2;
+  tensor1.data<float>()[1] = 0.5;
+  std::string res = PrintLodTensor(&tensor1, -1, 2);
+  ASSERT_EQ(res, "access violation");
+  res = PrintLodTensor(&tensor1, 0, 2);
+  ASSERT_EQ(res, ":0.2:0.5");
+
+  LoDTensor tensor2;
+  tensor2.Resize({2});
+  tensor2.mutable_data<int64_t>(platform::CPUPlace());
+  tensor2.data<int64_t>()[0] = 1;
+  tensor2.data<int64_t>()[1] = 2;
+  res = PrintLodTensor(&tensor2, -1, 2);
+  ASSERT_EQ(res, "access violation");
+  res = PrintLodTensor(&tensor2, 0, 2);
+  ASSERT_EQ(res, ":1:2");
+
+  LoDTensor tensor3;
+  tensor3.Resize({2});
+  tensor3.mutable_data<double>(platform::CPUPlace());
+  tensor3.data<double>()[0] = 0.1;
+  tensor3.data<double>()[1] = 0.2;
+  res = PrintLodTensor(&tensor3, 0, 2);
+  ASSERT_EQ(res, ":0.1:0.2");
 }
+
+TEST(LodTensor, GetTensorBound) {
+  LoD lod{{0, 2}};
+  LoDTensor tensor;
+  tensor.set_lod(lod);
+  tensor.Resize({2, 1});
+  tensor.mutable_data<float>(platform::CPUPlace());
+  tensor.data<float>()[0] = 0;
+  tensor.data<float>()[1] = 1;
+  std::pair<int64_t, int64_t> res = GetTensorBound(&tensor, 0);
+  ASSERT_EQ(res.first, 0);
+  ASSERT_EQ(res.second, 2);
 }
+
+TEST(LodTensor, CheckValidOutput) {
+  LoD lod{{0, 1, 2}};
+  LoDTensor tensor;
+  tensor.set_lod(lod);
+  tensor.Resize({2, 1});
+  tensor.mutable_data<float>(platform::CPUPlace());
+  tensor.data<float>()[0] = 0;
+  tensor.data<float>()[1] = 1;
+  ASSERT_TRUE(CheckValidOutput(&tensor, 2));
 }
+
+}  // namespace framework
+}  // namespace paddle

paddle/fluid/framework/downpour_worker.cc

Lines changed: 6 additions & 76 deletions
@@ -129,89 +129,19 @@ void DownpourWorker::SetNeedDump(bool need_dump_field) {
   need_dump_field_ = need_dump_field;
 }
 
-template <typename T>
-std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
-  auto count = tensor->numel();
-  if (start < 0 || end > count) {
-    VLOG(3) << "access violation";
-    return "access violation";
-  }
-  std::ostringstream os;
-  for (int64_t i = start; i < end; i++) {
-    os << ":" << tensor->data<T>()[i];
-  }
-  return os.str();
-}
-
-std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
-                                  int64_t end) {
-  auto count = tensor->numel();
-  if (start < 0 || end > count) {
-    VLOG(3) << "access violation";
-    return "access violation";
-  }
+void DownpourWorker::DumpParam(const int batch_id) {
   std::ostringstream os;
-  for (int64_t i = start; i < end; i++) {
-    os << ":" << static_cast<uint64_t>(tensor->data<int64_t>()[i]);
-  }
-  return os.str();
-}
-
-std::string DownpourWorker::PrintLodTensor(LoDTensor* tensor, int64_t start,
-                                           int64_t end) {
-  std::string out_val;
-  if (tensor->type() == proto::VarType::FP32) {
-    out_val = PrintLodTensorType<float>(tensor, start, end);
-  } else if (tensor->type() == proto::VarType::INT64) {
-    out_val = PrintLodTensorIntType(tensor, start, end);
-  } else if (tensor->type() == proto::VarType::FP64) {
-    out_val = PrintLodTensorType<double>(tensor, start, end);
-  } else {
-    out_val = "unsupported type";
-  }
-  return out_val;
-}
-
-std::pair<int64_t, int64_t> DownpourWorker::GetTensorBound(LoDTensor* tensor,
-                                                           int index) {
-  auto& dims = tensor->dims();
-  if (tensor->lod().size() != 0) {
-    auto& lod = tensor->lod()[0];
-    return {lod[index] * dims[1], lod[index + 1] * dims[1]};
-  } else {
-    return {index * dims[1], (index + 1) * dims[1]};
-  }
-}
-
-bool DownpourWorker::CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
-  auto& dims = tensor->dims();
-  if (dims.size() != 2) return false;
-  if (tensor->lod().size() != 0) {
-    auto& lod = tensor->lod()[0];
-    if (lod.size() != batch_size + 1) {
-      return false;
-    }
-  } else {
-    if (dims[0] != static_cast<int>(batch_size)) {
-      return false;
-    }
-  }
-  return true;
-}
-
-void DownpourWorker::DumpParam() {
-  std::string os;
   for (auto& param : dump_param_) {
-    os.clear();
-    os = param;
+    os.str("");
     Variable* var = thread_scope_->FindVar(param);
     if (var == nullptr) {
      continue;
    }
     LoDTensor* tensor = var->GetMutable<LoDTensor>();
     int64_t len = tensor->numel();
-    os += PrintLodTensor(tensor, 0, len);
-    writer_ << os;
+    os << "(" << batch_id << "," << param << ")"
+       << PrintLodTensor(tensor, 0, len);
+    writer_ << os.str();
   }
 }
 
@@ -1022,7 +952,7 @@ void DownpourWorker::TrainFiles() {
        writer_ << ars[i];
      }
      if (need_dump_param_ && thread_id_ == 0) {
-        DumpParam();
+        DumpParam(batch_cnt);
      }
    }
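DumpParam now tags every record with the batch id, so successive dumps of the same parameter can be told apart in the output stream. A small standalone illustration of the resulting record format follows; the parameter name and values are invented for the example.

// Standalone illustration of the DumpParam record format; fc_0.w_0 and the
// values below are invented for the example.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const int batch_id = 7;
  const std::string param = "fc_0.w_0";
  const std::vector<float> values = {0.2f, 0.5f, -1.0f};

  std::ostringstream os;
  os << "(" << batch_id << "," << param << ")";
  for (float v : values) {
    os << ":" << v;  // PrintLodTensor joins elements with a leading ':'
  }
  std::cout << os.str() << std::endl;  // prints: (7,fc_0.w_0):0.2:0.5:-1
  return 0;
}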

paddle/fluid/framework/downpour_worker_opt.cc

Lines changed: 1 addition & 1 deletion
@@ -564,7 +564,7 @@ void DownpourWorkerOpt::TrainFiles() {
        writer_ << ars[i];
      }
      if (need_dump_param_ && thread_id_ == 0) {
-        DumpParam();
+        DumpParam(batch_cnt);
      }
    }
