Skip to content

Commit b91190f

Browse files
authored
Improve CheckpointManager (#1170)
1 parent 2b407b5 commit b91190f

31 files changed

+526
-678
lines changed

native/src/fairseq2n/data/map_data_source.cc

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,14 @@ using PoolArgType = std::size_t;
2222

2323
namespace fairseq2n::detail {
2424

25-
map_data_source::map_data_source(std::unique_ptr<data_source> &&inner, std::vector<map_fn> &&fns,
26-
std::size_t num_parallel_calls, bool deterministic)
27-
: inner_{std::move(inner)}, map_fns_{std::move(fns)}, num_parallel_calls_{num_parallel_calls},
25+
map_data_source::map_data_source(
26+
std::unique_ptr<data_source> &&inner,
27+
std::vector<map_fn> &&fns,
28+
std::size_t num_parallel_calls,
29+
bool deterministic)
30+
: inner_{std::move(inner)},
31+
map_fns_{std::move(fns)},
32+
num_parallel_calls_{num_parallel_calls},
2833
deterministic_{deterministic || num_parallel_calls == 1},
2934
pool_{conditional_cast<PoolArgType>(deterministic ? 0U : num_parallel_calls)}
3035
{
@@ -50,8 +55,13 @@ map_data_source::next()
5055
do {
5156
// Yield a buffered example.
5257
for (; buffer_pos_ < buffer_.end(); ++buffer_pos_) {
53-
if (*buffer_pos_)
54-
return std::move(*buffer_pos_++);
58+
if (buffer_pos_->has_value()) {
59+
std::optional<data> output = std::exchange(*buffer_pos_, std::nullopt);
60+
61+
++buffer_pos_;
62+
63+
return output;
64+
}
5565
}
5666
// If we have exhausted all buffered examples, try to refill the buffer.
5767
} while (fill_buffer());
@@ -168,6 +178,8 @@ map_data_source::fill_buffer()
168178
buffer_.push_back(std::move(maybe_example));
169179
}
170180

181+
buffer_pos_ = buffer_.begin();
182+
171183
if (buffer_.empty())
172184
return false;
173185

@@ -183,8 +195,6 @@ map_data_source::fill_buffer()
183195
else
184196
parallel_for<std::size_t>(apply_function, buffer_.size());
185197

186-
buffer_pos_ = buffer_.begin();
187-
188198
return true;
189199
}
190200

native/src/fairseq2n/data/map_data_source.h

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,11 @@ namespace fairseq2n::detail {
2424

2525
class map_data_source final : public data_source {
2626
public:
27-
explicit map_data_source(std::unique_ptr<data_source> &&inner, std::vector<map_fn> &&fns,
28-
std::size_t num_parallel_calls, bool deterministic_);
27+
explicit map_data_source(
28+
std::unique_ptr<data_source> &&inner,
29+
std::vector<map_fn> &&fns,
30+
std::size_t num_parallel_calls,
31+
bool deterministic_);
2932

3033
std::optional<data>
3134
next() override;

0 commit comments

Comments
 (0)