Skip to content

Commit 598f63a

Browse files
authored
Fix issue #939 - tokens batch_type may exceed max_batch_size (#1948)
* 1. Fix the batching logic to include padding tokens in the batch-size increment in the BatchReader.get_next method. rebatch_input will always pass batch_increment_is_fixed=true; since rebatch_input sorts the input by length in descending order, the first example in every batch will be the longest, so the batch increment is fixed at the length of the longest example in the batch. This solves issue #939. Since batch_increment_is_fixed is false by default, the change does not affect the prefetching logic, as mentioned in the revert of the previous PR addressing this issue: #1314. 2. Fix the same issue in the _batch_iterator method of the CTranslate2/python/ctranslate2/extensions.py module. 3. Add tests for both changes * Add a comment * 1. Improve the get_next method implementation by allowing it to work with unsorted example input. 2. Fix memory over-allocation when batch_type=tokens: if we reserve max_batch_size entries for the batch vector, we may over-allocate memory, so shrink_to_fit is needed before returning the batch. * 1. Rename the batch_size_increment_is_fixed variable to consider_padding 2. Update documentation
1 parent 71cdf3a commit 598f63a

File tree

7 files changed

+218
-37
lines changed

7 files changed

+218
-37
lines changed

include/ctranslate2/batch_reader.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ namespace ctranslate2 {
5656

5757
std::vector<Example>
5858
get_next(const size_t max_batch_size,
59-
const BatchType batch_type = BatchType::Examples);
59+
const BatchType batch_type = BatchType::Examples,
60+
const bool consider_padding = false);
6061

6162
// Consumes and returns the next example.
6263
virtual Example get_next_example() = 0;
@@ -67,6 +68,12 @@ namespace ctranslate2 {
6768
}
6869

6970
private:
71+
std::vector<Example> fill_batch_with_fixed_increment(const size_t max_batch_size,
72+
const BatchType batch_type);
73+
74+
std::vector<Example> fill_batch_with_variable_increment(const size_t max_batch_size,
75+
const BatchType batch_type);
76+
7077
bool _initialized = false;
7178
Example _next;
7279
};

python/cpp/generator.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,10 @@ namespace ctranslate2 {
234234
Arguments:
235235
start_tokens: Batch of start tokens. If the decoder starts from a special
236236
start token like ``<s>``, this token should be added to this input.
237-
max_batch_size: The maximum batch size. If the number of inputs is greater than
238-
:obj:`max_batch_size`, the inputs are sorted by length and split by chunks of
239-
:obj:`max_batch_size` examples so that the number of padding positions is
240-
minimized.
237+
max_batch_size: The maximum batch size. If the number of inputs is greater than :obj:`max_batch_size`,
238+
the inputs are sorted by length and split by chunks of :obj:`max_batch_size` examples
239+
(or tokens when :obj:`batch_type`="tokens") so that the number of padding positions
240+
is minimized.
241241
batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
242242
asynchronous: Run the generation asynchronously.
243243
beam_size: Beam size (1 for greedy search).

python/cpp/translator.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,10 @@ namespace ctranslate2 {
372372
Arguments:
373373
source: Batch of source tokens.
374374
target_prefix: Optional batch of target prefix tokens.
375-
max_batch_size: The maximum batch size. If the number of inputs is greater than
376-
:obj:`max_batch_size`, the inputs are sorted by length and split by chunks of
377-
:obj:`max_batch_size` examples so that the number of padding positions is
378-
minimized.
375+
max_batch_size: The maximum batch size. If the number of inputs is greater than :obj:`max_batch_size`,
376+
the inputs are sorted by length and split by chunks of :obj:`max_batch_size` examples
377+
(or tokens when :obj:`batch_type`="tokens") so that the number of padding positions
378+
is minimized.
379379
batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
380380
asynchronous: Run the translation asynchronously.
381381
beam_size: Beam size (1 for greedy search).

python/ctranslate2/extensions.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -556,30 +556,34 @@ def _process_iterable(process_func, iterables, max_batch_size, batch_type, **kwa
556556

557557
def _batch_iterator(iterable, batch_size, batch_type):
558558
streams = None
559-
cur_batch_size = 0
559+
max_length = 0
560560

561561
for example in iterable:
562562
if not isinstance(example, tuple):
563563
example = (example,)
564564

565+
if batch_type == "examples":
566+
if streams and len(streams[0]) == batch_size:
567+
yield streams
568+
streams = None
569+
570+
elif batch_type == "tokens":
571+
max_length = max(max_length, len(example[0]))
572+
573+
if streams and (len(streams[0]) + 1) * max_length > batch_size:
574+
yield streams
575+
streams = None
576+
max_length = len(example[0])
577+
578+
else:
579+
raise ValueError("Invalid batch type %s" % batch_type)
580+
565581
if streams is None:
566582
streams = tuple([] for _ in example)
567583
for batch, element in zip(streams, example):
568584
if element is None and len(streams) > 1:
569585
raise ValueError("Input iterables do not have the same length")
570586
batch.append(element)
571587

572-
if batch_type == "examples":
573-
cur_batch_size += 1
574-
elif batch_type == "tokens":
575-
cur_batch_size += len(example[0])
576-
else:
577-
raise ValueError("Invalid batch type %s" % batch_type)
578-
579-
if cur_batch_size >= batch_size:
580-
yield streams
581-
streams = None
582-
cur_batch_size = 0
583-
584588
if streams is not None:
585589
yield streams

python/tests/test_misc.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
from ctranslate2.extensions import _batch_iterator as batch_iterator
4+
5+
6+
@pytest.mark.parametrize(
7+
"batch_size,batch_type,lengths,expected_batch_sizes",
8+
[
9+
(2, "examples", [2, 3, 4, 1, 1], [2, 2, 1]),
10+
(6, "tokens", [2, 3, 1, 4, 1, 2], [2, 1, 1, 2]),
11+
],
12+
)
13+
def test_batch_iterator(batch_size, batch_type, lengths, expected_batch_sizes):
14+
iterable = (["a"] * length for length in lengths)
15+
16+
batches = batch_iterator(iterable, batch_size, batch_type)
17+
batch_sizes = [len(batch[0]) for batch in batches]
18+
19+
assert batch_sizes == expected_batch_sizes

src/batch_reader.cc

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,33 +36,69 @@ namespace ctranslate2 {
3636
}
3737

3838
std::vector<Example>
39-
BatchReader::get_next(const size_t max_batch_size,
40-
const BatchType batch_type) {
41-
if (max_batch_size == 0)
42-
throw std::invalid_argument("BatchReader: max_batch_size must be > 0");
39+
BatchReader::fill_batch_with_fixed_increment(const size_t max_batch_size,
40+
const BatchType batch_type) {
41+
std::vector<Example> batch;
42+
batch.reserve(max_batch_size);
4343

44-
if (!_initialized) {
44+
size_t max_increment = 0;
45+
46+
while (!_next.empty()) {
47+
const size_t cur_increment = get_batch_size_increment(_next, batch_type);
48+
max_increment = std::max(max_increment, cur_increment);
49+
const size_t new_batch_size = (batch.size() + 1) * max_increment;
50+
51+
if (!batch.empty() && new_batch_size > max_batch_size)
52+
break;
53+
54+
batch.emplace_back(std::move(_next));
4555
_next = get_next_example();
46-
_initialized = true;
4756
}
57+
return batch;
58+
}
4859

60+
std::vector<Example>
61+
BatchReader::fill_batch_with_variable_increment(const size_t max_batch_size,
62+
const BatchType batch_type) {
4963
std::vector<Example> batch;
50-
if (_next.empty())
51-
return batch;
52-
5364
batch.reserve(max_batch_size);
5465

55-
size_t batch_size = 0;
66+
size_t total_increment = 0;
5667

5768
while (!_next.empty()) {
58-
const size_t batch_size_increment = get_batch_size_increment(_next, batch_type);
59-
if (batch_size > 0 && batch_size + batch_size_increment > max_batch_size)
69+
const size_t cur_increment = get_batch_size_increment(_next, batch_type);
70+
const size_t new_batch_size = total_increment + cur_increment;
71+
72+
if (!batch.empty() && new_batch_size > max_batch_size)
6073
break;
74+
6175
batch.emplace_back(std::move(_next));
62-
batch_size += batch_size_increment;
76+
total_increment += cur_increment;
77+
_next = get_next_example();
78+
}
79+
return batch;
80+
}
81+
82+
std::vector<Example>
83+
BatchReader::get_next(const size_t max_batch_size,
84+
const BatchType batch_type,
85+
const bool consider_padding) {
86+
if (max_batch_size == 0)
87+
throw std::invalid_argument("BatchReader: max_batch_size must be > 0");
88+
89+
if (!_initialized) {
6390
_next = get_next_example();
91+
_initialized = true;
6492
}
6593

94+
if (_next.empty())
95+
return {};
96+
97+
auto batch = consider_padding
98+
? fill_batch_with_fixed_increment(max_batch_size, batch_type)
99+
: fill_batch_with_variable_increment(max_batch_size, batch_type);
100+
101+
batch.shrink_to_fit();
66102
return batch;
67103
}
68104

@@ -170,7 +206,8 @@ namespace ctranslate2 {
170206
VectorReader batch_reader(index_vector(examples, example_index));
171207

172208
for (size_t offset = 0;;) {
173-
auto examples_part = batch_reader.get_next(max_batch_size, batch_type);
209+
// the batch size increment per example is always fixed because padding is required
210+
auto examples_part = batch_reader.get_next(max_batch_size, batch_type, true);
174211
if (examples_part.empty())
175212
break;
176213

@@ -189,4 +226,4 @@ namespace ctranslate2 {
189226
return batches;
190227
}
191228

192-
}
229+
}

tests/batching_test.cc

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,117 @@ TEST(BatchingTest, RebatchInput) {
3636
EXPECT_EQ(batch.example_index, expected_batches[i]);
3737
}
3838
}
39+
40+
TEST(BatchingTest, BatchReaderGetNext_Examples) {
41+
const std::vector<std::vector<std::string>> examples = {
42+
{"a", "b"},
43+
{"a", "b", "c"},
44+
{"a"},
45+
{"a", "b", "c", "d"}
46+
};
47+
const std::vector<std::vector<size_t>> expected_batches = {{0, 1}, {2, 3}};
48+
49+
VectorReader reader(examples);
50+
51+
for (const auto& expected_batch : expected_batches) {
52+
auto batch = reader.get_next(2, BatchType::Examples, true);
53+
ASSERT_EQ(batch.size(), expected_batch.size());
54+
for (size_t i = 0; i < batch.size(); ++i) {
55+
EXPECT_EQ(batch[i].streams[0], examples[expected_batch[i]]);
56+
}
57+
}
58+
}
59+
60+
TEST(BatchingTest, BatchReaderGetNext_TokensFixed) {
61+
const std::vector<std::vector<std::string>> source = {
62+
{"a", "b", "c", "d"},
63+
{"a", "b", "c", "d", "e"},
64+
{"a"},
65+
{"a", "b", "c"},
66+
{"a", "b"}
67+
};
68+
const std::vector<std::vector<std::string>> target = {
69+
{"1"},
70+
{"2"},
71+
{"3"},
72+
{"4"},
73+
{"5"}
74+
};
75+
76+
const std::vector<std::vector<size_t>> expected_batches = {{1}, {0}, {3, 4}, {2}};
77+
78+
const auto batches = rebatch_input(load_examples({source, target}), 6, BatchType::Tokens);
79+
ASSERT_EQ(batches.size(), expected_batches.size());
80+
81+
for (size_t i = 0; i < batches.size(); ++i) {
82+
const auto& batch = batches[i];
83+
EXPECT_EQ(batch.get_stream(0), index_vector(source, expected_batches[i]));
84+
EXPECT_EQ(batch.get_stream(1), index_vector(target, expected_batches[i]));
85+
EXPECT_EQ(batch.example_index, expected_batches[i]);
86+
}
87+
}
88+
89+
TEST(BatchingTest, BatchReaderGetNext_TokensDynamic) {
90+
const std::vector<std::vector<std::string>> examples = {
91+
{"a", "b"},
92+
{"a", "b", "c"},
93+
{"a"},
94+
{"a", "b", "c", "d"},
95+
{"a", "b", "c", "d", "e"}
96+
};
97+
98+
const std::vector<std::vector<size_t>> expected_batches = {{0, 1, 2}, {3}, {4}};
99+
100+
VectorReader reader(examples);
101+
102+
for (const auto& expected_batch : expected_batches) {
103+
auto batch = reader.get_next(6, BatchType::Tokens, false);
104+
ASSERT_EQ(batch.size(), expected_batch.size());
105+
for (size_t i = 0; i < batch.size(); ++i) {
106+
EXPECT_EQ(batch[i].streams[0], examples[expected_batch[i]]);
107+
}
108+
}
109+
}
110+
111+
TEST(BatchingTest, BatchReaderGetNext_TokensFixed2) {
112+
const std::vector<std::vector<std::string>> source = {
113+
{"a", "b", "c", "d", "e"},
114+
{"a", "b"},
115+
{"a"}
116+
};
117+
const std::vector<std::vector<std::string>> target = {
118+
{"1"},
119+
{"2"},
120+
{"3"}
121+
};
122+
123+
const std::vector<std::vector<size_t>> expected_batches = {{0}, {1, 2}};
124+
const auto batches = rebatch_input(load_examples({source, target}), 8, BatchType::Tokens);
125+
ASSERT_EQ(batches.size(), expected_batches.size());
126+
127+
for (size_t i = 0; i < batches.size(); ++i) {
128+
const auto& batch = batches[i];
129+
EXPECT_EQ(batch.get_stream(0), index_vector(source, expected_batches[i]));
130+
EXPECT_EQ(batch.get_stream(1), index_vector(target, expected_batches[i]));
131+
EXPECT_EQ(batch.example_index, expected_batches[i]);
132+
}
133+
}
134+
135+
TEST(BatchingTest, BatchReaderGetNext_TokensDynamic2) {
136+
const std::vector<std::vector<std::string>> source = {
137+
{"a", "b", "c", "d", "e"},
138+
{"a", "b"},
139+
{"a"}
140+
};
141+
142+
const std::vector<std::vector<size_t>> expected_batches = {{0, 1, 2}};
143+
VectorReader reader(source);
144+
145+
for (const auto& expected_batch : expected_batches) {
146+
auto batch = reader.get_next(8, BatchType::Tokens, false);
147+
ASSERT_EQ(batch.size(), expected_batch.size());
148+
for (size_t i = 0; i < batch.size(); ++i) {
149+
EXPECT_EQ(batch[i].streams[0], source[expected_batch[i]]);
150+
}
151+
}
152+
}

0 commit comments

Comments
 (0)