Skip to content

Commit 3db61dc

Browse files
authored
cherry-pick 1.6: simplify master+patch, remove ins when size != merge_size or has conflict slot (#20941)
* simplify master+patch, remove ins when size != merge_size or has conflict slot
* test=develop
1 parent 5c3656b commit 3db61dc

File tree

4 files changed

+65
-154
lines changed

4 files changed

+65
-154
lines changed

paddle/fluid/framework/data_set.cc

Lines changed: 56 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <algorithm>
1717
#include <random>
1818
#include <unordered_map>
19+
#include <unordered_set>
1920
#include "google/protobuf/io/zero_copy_stream_impl.h"
2021
#include "google/protobuf/message.h"
2122
#include "google/protobuf/text_format.h"
@@ -45,9 +46,7 @@ DatasetImpl<T>::DatasetImpl() {
4546
fleet_send_batch_size_ = 1024;
4647
fleet_send_sleep_seconds_ = 0;
4748
merge_by_insid_ = false;
48-
erase_duplicate_feas_ = true;
49-
keep_unmerged_ins_ = true;
50-
min_merge_size_ = 2;
49+
merge_size_ = 2;
5150
parse_ins_id_ = false;
5251
parse_content_ = false;
5352
preload_thread_num_ = 0;
@@ -118,15 +117,10 @@ void DatasetImpl<T>::SetParseContent(bool parse_content) {
118117
}
119118

120119
template <typename T>
121-
void DatasetImpl<T>::SetMergeByInsId(
122-
const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
123-
int min_merge_size, bool keep_unmerged_ins) {
120+
void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
124121
merge_by_insid_ = true;
125122
parse_ins_id_ = true;
126-
merge_slots_list_ = merge_slot_list;
127-
erase_duplicate_feas_ = erase_duplicate_feas;
128-
min_merge_size_ = min_merge_size;
129-
keep_unmerged_ins_ = keep_unmerged_ins;
123+
merge_size_ = merge_size;
130124
}
131125

132126
template <typename T>
@@ -643,22 +637,11 @@ void MultiSlotDataset::MergeByInsId() {
643637
return;
644638
}
645639
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
646-
std::unordered_map<int, bool> merge_slots;
647640
std::vector<std::string> use_slots;
648-
std::vector<bool> use_slots_is_dense;
649641
for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) {
650642
const auto& slot = multi_slot_desc.slots(i);
651643
if (slot.is_used()) {
652644
use_slots.push_back(slot.name());
653-
use_slots_is_dense.push_back(slot.is_dense());
654-
}
655-
}
656-
for (size_t i = 0; i < use_slots.size(); ++i) {
657-
// currently, we don't merge dense slots
658-
if (std::find(merge_slots_list_.begin(), merge_slots_list_.end(),
659-
use_slots[i]) != merge_slots_list_.end() &&
660-
!use_slots_is_dense[i]) {
661-
merge_slots[i] = true;
662645
}
663646
}
664647
CHECK(multi_output_channel_.size() != 0); // NOLINT
@@ -682,134 +665,82 @@ void MultiSlotDataset::MergeByInsId() {
682665
return a.ins_id_ < b.ins_id_;
683666
});
684667

685-
auto sort_cmp_uint64 = [&merge_slots](const FeatureItem& a,
686-
const FeatureItem& b) {
687-
auto& a_sign = a.sign().uint64_feasign_;
688-
auto& b_sign = b.sign().uint64_feasign_;
689-
return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
690-
};
691-
auto sort_cmp_float = [&merge_slots](const FeatureItem& a,
692-
const FeatureItem& b) {
693-
auto& a_sign = a.sign().float_feasign_;
694-
auto& b_sign = b.sign().float_feasign_;
695-
return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
696-
};
697-
auto unique_eq_uint64 = [&merge_slots](const FeatureItem& a,
698-
const FeatureItem& b) {
699-
if (a.slot() == b.slot() &&
700-
merge_slots.find(a.slot()) == merge_slots.end()) {
701-
return true;
702-
}
703-
auto& a_sign = a.sign().uint64_feasign_;
704-
auto& b_sign = b.sign().uint64_feasign_;
705-
return a_sign == b_sign && a.slot() == b.slot();
706-
};
707-
auto unique_eq_float = [&merge_slots](const FeatureItem& a,
708-
const FeatureItem& b) {
709-
if (a.slot() == b.slot() &&
710-
merge_slots.find(a.slot()) == merge_slots.end()) {
711-
return true;
712-
}
713-
auto& a_sign = a.sign().float_feasign_;
714-
auto& b_sign = b.sign().float_feasign_;
715-
return a_sign == b_sign && a.slot() == b.slot();
716-
};
717-
718668
std::vector<Record> results;
669+
uint64_t drop_ins_num = 0;
670+
std::unordered_set<uint16_t> all_int64;
671+
std::unordered_set<uint16_t> all_float;
672+
std::unordered_set<uint16_t> local_uint64;
673+
std::unordered_set<uint16_t> local_float;
674+
719675
VLOG(3) << "recs.size() " << recs.size();
720676
for (size_t i = 0; i < recs.size();) {
721677
size_t j = i + 1;
722678
while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) {
723679
j++;
724680
}
725-
if (j - i < min_merge_size_) {
726-
if (keep_unmerged_ins_) {
727-
for (size_t k = i; k < j; ++k) {
728-
results.push_back(std::move(recs[k]));
729-
}
730-
}
681+
if (merge_size_ > 0 && j - i != merge_size_) {
682+
drop_ins_num += j - i;
683+
LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
684+
<< ", because merge_size=" << merge_size_;
731685
i = j;
732686
continue;
733687
}
734688

735-
std::vector<FeatureItem> merge_uint64_feasigns;
736-
std::vector<FeatureItem> merge_float_feasigns;
737-
Record rec = std::move(recs[i]);
689+
all_int64.clear();
690+
all_float.clear();
691+
bool has_conflict_slot = false;
692+
uint16_t conflict_slot = 0;
693+
694+
Record rec;
695+
rec.ins_id_ = recs[i].ins_id_;
696+
rec.content_ = recs[i].content_;
738697

739-
for (size_t k = i + 1; k < j; k++) {
698+
for (size_t k = i; k < j; k++) {
699+
local_uint64.clear();
700+
local_float.clear();
740701
for (auto& feature : recs[k].uint64_feasigns_) {
741-
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
742-
merge_uint64_feasigns.push_back(std::move(feature));
702+
uint16_t slot = feature.slot();
703+
if (all_int64.find(slot) != all_int64.end()) {
704+
has_conflict_slot = true;
705+
conflict_slot = slot;
706+
break;
743707
}
708+
local_uint64.insert(slot);
709+
rec.uint64_feasigns_.push_back(std::move(feature));
710+
}
711+
if (has_conflict_slot) {
712+
break;
744713
}
714+
all_int64.insert(local_uint64.begin(), local_uint64.end());
715+
745716
for (auto& feature : recs[k].float_feasigns_) {
746-
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
747-
merge_float_feasigns.push_back(std::move(feature));
717+
uint16_t slot = feature.slot();
718+
if (all_float.find(slot) != all_float.end()) {
719+
has_conflict_slot = true;
720+
conflict_slot = slot;
721+
break;
748722
}
723+
local_float.insert(slot);
724+
rec.float_feasigns_.push_back(std::move(feature));
725+
}
726+
if (has_conflict_slot) {
727+
break;
749728
}
750-
recs[k] = Record();
729+
all_float.insert(local_float.begin(), local_float.end());
751730
}
752-
i = j;
753731

754-
if (!erase_duplicate_feas_) {
755-
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
756-
merge_uint64_feasigns.begin(),
757-
merge_uint64_feasigns.end());
758-
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
759-
merge_float_feasigns.begin(),
760-
merge_float_feasigns.end());
732+
if (has_conflict_slot) {
733+
LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
734+
<< ", because conflict_slot=" << use_slots[conflict_slot];
735+
drop_ins_num += j - i;
761736
} else {
762-
std::vector<FeatureItem> not_merge_uint64_feasigns;
763-
std::vector<FeatureItem> not_merge_float_feasigns;
764-
765-
for (auto& feature : rec.uint64_feasigns_) {
766-
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
767-
merge_uint64_feasigns.push_back(std::move(feature));
768-
} else {
769-
not_merge_uint64_feasigns.push_back(std::move(feature));
770-
}
771-
}
772-
for (auto& feature : rec.float_feasigns_) {
773-
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
774-
merge_float_feasigns.push_back(std::move(feature));
775-
} else {
776-
not_merge_float_feasigns.push_back(std::move(feature));
777-
}
778-
}
779-
rec.uint64_feasigns_.clear();
780-
rec.float_feasigns_.clear();
781-
782-
// erase duplicate uint64 feasigns
783-
std::sort(merge_uint64_feasigns.begin(), merge_uint64_feasigns.end(),
784-
sort_cmp_uint64);
785-
merge_uint64_feasigns.erase(
786-
std::unique(merge_uint64_feasigns.begin(),
787-
merge_uint64_feasigns.end(), unique_eq_uint64),
788-
merge_uint64_feasigns.end());
789-
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
790-
merge_uint64_feasigns.begin(),
791-
merge_uint64_feasigns.end());
792-
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
793-
not_merge_uint64_feasigns.begin(),
794-
not_merge_uint64_feasigns.end());
795-
796-
// erase duplicate float feasigns
797-
std::sort(merge_float_feasigns.begin(), merge_float_feasigns.end(),
798-
sort_cmp_float);
799-
merge_float_feasigns.erase(
800-
std::unique(merge_float_feasigns.begin(), merge_float_feasigns.end(),
801-
unique_eq_float),
802-
merge_float_feasigns.end());
803-
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
804-
merge_float_feasigns.begin(),
805-
merge_float_feasigns.end());
806-
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
807-
not_merge_float_feasigns.begin(),
808-
not_merge_float_feasigns.end());
737+
results.push_back(std::move(rec));
809738
}
810-
results.push_back(rec);
739+
i = j;
811740
}
741+
std::vector<Record>().swap(recs);
812742
VLOG(3) << "results size " << results.size();
743+
LOG(WARNING) << "total drop ins num: " << drop_ins_num;
813744
results.shrink_to_fit();
814745

815746
auto fleet_ptr = FleetWrapper::GetInstance();

paddle/fluid/framework/data_set.h

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ class Dataset {
6262
virtual void SetParseInsId(bool parse_ins_id) = 0;
6363
virtual void SetParseContent(bool parse_content) = 0;
6464
// set merge by ins id
65-
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
66-
bool erase_duplicate_feas, int min_merge_size,
67-
bool keep_unmerged_ins) = 0;
65+
virtual void SetMergeByInsId(int merge_size) = 0;
6866
// set fea eval mode
6967
virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
7068
// get file list
@@ -149,9 +147,7 @@ class DatasetImpl : public Dataset {
149147
virtual void SetChannelNum(int channel_num);
150148
virtual void SetParseInsId(bool parse_ins_id);
151149
virtual void SetParseContent(bool parse_content);
152-
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
153-
bool erase_duplicate_feas, int min_merge_size,
154-
bool keep_unmerged_ins);
150+
virtual void SetMergeByInsId(int merge_size);
155151

156152
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
157153
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
@@ -219,10 +215,7 @@ class DatasetImpl : public Dataset {
219215
bool merge_by_insid_;
220216
bool parse_ins_id_;
221217
bool parse_content_;
222-
bool erase_duplicate_feas_;
223-
bool keep_unmerged_ins_;
224-
int min_merge_size_;
225-
std::vector<std::string> merge_slots_list_;
218+
int merge_size_;
226219
bool slots_shuffle_fea_eval_ = false;
227220
int preload_thread_num_;
228221
std::mutex global_index_mutex_;

python/paddle/fluid/dataset.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -408,26 +408,13 @@ def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0):
408408
"""
409409
self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
410410

411-
def set_merge_by_lineid(self,
412-
var_list,
413-
erase_duplicate_feas=True,
414-
min_merge_size=2,
415-
keep_unmerged_ins=True):
411+
def set_merge_by_lineid(self, merge_size=2):
416412
"""
417413
Set merge by line id, instances of same line id will be merged after
418414
shuffle, you should parse line id in data generator.
419415
420416
Args:
421-
var_list(list): slots that can be merge. each element in var_list
422-
is Variable. some slots such as show and click, we
423-
usually don't merge them for same line id, so user
424-
should specify which slot can be merged.
425-
erase_duplicate_feas(bool): whether erase duplicate feasigns when
426-
merge. default is True.
427-
min_merge_size(int): minimal size to merge. default is 2.
428-
keep_unmerged_ins(bool): whether to keep unmerged ins, such as
429-
ins with unique id or the num of ins with
430-
same id is less than min_merge_size.
417+
merge_size(int): ins size to merge. default is 2.
431418
432419
Examples:
433420
.. code-block:: python
@@ -437,10 +424,9 @@ def set_merge_by_lineid(self,
437424
dataset.set_merge_by_lineid()
438425
439426
"""
440-
var_name_list = [i.name for i in var_list]
441-
self.dataset.set_merge_by_lineid(var_name_list, erase_duplicate_feas,
442-
min_merge_size, keep_unmerged_ins)
427+
self.dataset.set_merge_by_lineid(merge_size)
443428
self.merge_by_lineid = True
429+
self.parse_ins_id = True
444430

445431
def load_into_memory(self):
446432
"""

python/paddle/fluid/tests/unittests/test_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,8 @@ def test_in_memory_dataset_run_2(self):
272272
except Exception as e:
273273
self.assertTrue(False)
274274

275-
dataset.set_merge_by_lineid(slots_vars)
275+
dataset.set_merge_by_lineid(2)
276+
dataset.set_parse_ins_id(False)
276277
dataset.set_fleet_send_sleep_seconds(2)
277278
dataset.preload_into_memory()
278279
dataset.wait_preload_done()

0 commit comments

Comments
 (0)