Skip to content

Commit 79bd5f9

Browse files
authored
add slot record dataset (#36200)
1 parent 83578cf commit 79bd5f9

File tree

8 files changed

+622
-46
lines changed

8 files changed

+622
-46
lines changed

paddle/fluid/framework/channel.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,19 @@ class ChannelObject {
157157
p.resize(finished);
158158
return finished;
159159
}
160+
// read once only
161+
size_t ReadOnce(std::vector<T>& p, size_t size) { // NOLINT
162+
if (size == 0) {
163+
return 0;
164+
}
165+
std::unique_lock<std::mutex> lock(mutex_);
166+
p.resize(size);
167+
size_t finished = Read(size, &p[0], lock, true);
168+
p.resize(finished);
169+
Notify();
160170

171+
return finished;
172+
}
161173
size_t ReadAll(std::vector<T>& p) { // NOLINT
162174
p.clear();
163175
size_t finished = 0;
@@ -241,17 +253,21 @@ class ChannelObject {
241253
return !closed_;
242254
}
243255

244-
size_t Read(size_t n, T* p, std::unique_lock<std::mutex>& lock) { // NOLINT
256+
size_t Read(size_t n, T* p, std::unique_lock<std::mutex>& lock, // NOLINT
257+
bool once = false) { // NOLINT
245258
size_t finished = 0;
246259
CHECK(n <= MaxCapacity() - reading_count_);
247260
reading_count_ += n;
248261
while (finished < n && WaitForRead(lock)) {
249-
size_t m = std::min(n - finished, data_.size());
262+
size_t m = (std::min)(n - finished, data_.size());
250263
for (size_t i = 0; i < m; i++) {
251264
p[finished++] = std::move(data_.front());
252265
data_.pop_front();
253266
}
254267
reading_count_ -= m;
268+
if (once && m > 0) {
269+
break;
270+
}
255271
}
256272
reading_count_ -= n - finished;
257273
return finished;

paddle/fluid/framework/data_feed.cc

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,107 @@ DLManager& global_dlmanager_pool() {
3636
return manager;
3737
}
3838

39+
class BufferedLineFileReader {
40+
typedef std::function<bool()> SampleFunc;
41+
static const int MAX_FILE_BUFF_SIZE = 4 * 1024 * 1024;
42+
class FILEReader {
43+
public:
44+
explicit FILEReader(FILE* fp) : fp_(fp) {}
45+
int read(char* buf, int len) { return fread(buf, sizeof(char), len, fp_); }
46+
47+
private:
48+
FILE* fp_;
49+
};
50+
51+
public:
52+
typedef std::function<bool(const std::string&)> LineFunc;
53+
54+
private:
55+
template <typename T>
56+
int read_lines(T* reader, LineFunc func, int skip_lines) {
57+
int lines = 0;
58+
size_t ret = 0;
59+
char* ptr = NULL;
60+
char* eol = NULL;
61+
total_len_ = 0;
62+
error_line_ = 0;
63+
64+
SampleFunc spfunc = get_sample_func();
65+
std::string x;
66+
while (!is_error() && (ret = reader->read(buff_, MAX_FILE_BUFF_SIZE)) > 0) {
67+
total_len_ += ret;
68+
ptr = buff_;
69+
eol = reinterpret_cast<char*>(memchr(ptr, '\n', ret));
70+
while (eol != NULL) {
71+
int size = static_cast<int>((eol - ptr) + 1);
72+
x.append(ptr, size - 1);
73+
++lines;
74+
if (lines > skip_lines && spfunc()) {
75+
if (!func(x)) {
76+
++error_line_;
77+
}
78+
}
79+
80+
x.clear();
81+
ptr += size;
82+
ret -= size;
83+
eol = reinterpret_cast<char*>(memchr(ptr, '\n', ret));
84+
}
85+
if (ret > 0) {
86+
x.append(ptr, ret);
87+
}
88+
}
89+
if (!is_error() && !x.empty()) {
90+
++lines;
91+
if (lines > skip_lines && spfunc()) {
92+
if (!func(x)) {
93+
++error_line_;
94+
}
95+
}
96+
}
97+
return lines;
98+
}
99+
100+
public:
101+
BufferedLineFileReader()
102+
: random_engine_(std::random_device()()),
103+
uniform_distribution_(0.0f, 1.0f) {
104+
total_len_ = 0;
105+
sample_line_ = 0;
106+
buff_ =
107+
reinterpret_cast<char*>(calloc(MAX_FILE_BUFF_SIZE + 1, sizeof(char)));
108+
}
109+
~BufferedLineFileReader() { free(buff_); }
110+
111+
int read_file(FILE* fp, LineFunc func, int skip_lines) {
112+
FILEReader reader(fp);
113+
return read_lines<FILEReader>(&reader, func, skip_lines);
114+
}
115+
uint64_t file_size(void) { return total_len_; }
116+
void set_sample_rate(float r) { sample_rate_ = r; }
117+
size_t get_sample_line() { return sample_line_; }
118+
bool is_error(void) { return (error_line_ > 10); }
119+
120+
private:
121+
SampleFunc get_sample_func() {
122+
if (std::abs(sample_rate_ - 1.0f) < 1e-5f) {
123+
return [this](void) { return true; };
124+
}
125+
return [this](void) {
126+
return (uniform_distribution_(random_engine_) < sample_rate_);
127+
};
128+
}
129+
130+
private:
131+
char* buff_ = nullptr;
132+
uint64_t total_len_ = 0;
133+
134+
std::default_random_engine random_engine_;
135+
std::uniform_real_distribution<float> uniform_distribution_;
136+
float sample_rate_ = 1.0f;
137+
size_t sample_line_ = 0;
138+
size_t error_line_ = 0;
139+
};
39140
void RecordCandidateList::ReSize(size_t length) {
40141
mutex_.lock();
41142
capacity_ = length;
@@ -301,7 +402,7 @@ int InMemoryDataFeed<T>::Next() {
301402
<< ", thread_id=" << thread_id_;
302403
}
303404
} else {
304-
VLOG(3) << "enable heter NEXT: " << offset_index_
405+
VLOG(3) << "enable heter next: " << offset_index_
305406
<< " batch_offsets: " << batch_offsets_.size();
306407
if (offset_index_ >= batch_offsets_.size()) {
307408
VLOG(3) << "offset_index: " << offset_index_
@@ -318,14 +419,7 @@ int InMemoryDataFeed<T>::Next() {
318419
VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
319420
<< thread_id_;
320421
}
321-
/*
322-
if (offset_index_ == batch_offsets_.size() - 1) {
323-
std::vector<Record> data;
324-
output_channel_->ReadAll(data);
325-
consume_channel_->Write(std::move(data));
326-
}
327-
*/
328-
VLOG(3) << "#15 enable heter NEXT: " << offset_index_
422+
VLOG(3) << "enable heter next: " << offset_index_
329423
<< " batch_offsets: " << batch_offsets_.size()
330424
<< " baych_size: " << this->batch_size_;
331425
}

0 commit comments

Comments
 (0)