Commit 9b80483

fix pad vs. eos token misidentification (microsoft#1694)

Addresses microsoft#1682, the misidentification of EOS tokens as padding tokens. For models that share the EOS and pad token ids, every EOS token present in the input was identified as a pad token, producing an incorrect attention_mask and parity issues later on. This PR fixes the issue by adding a special case for batch_size 1 (where input padding is unnecessary and assumed absent), and by iterating through the input backwards when batch_size > 1, so as not to misinterpret EOS tokens that are not at the end of the sequence.

1 parent 101f208 commit 9b80483
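
In short: the old code walked each batch row forwards and zeroed every token equal to pad_token_id, which also zeroed genuine EOS tokens whenever the two ids coincide. The fix walks each row backwards and zeroes only the trailing run of pad-id tokens. Below is a minimal standalone sketch of the before/after mask behavior, not the library code; the shared EOS/pad id 50256 and the token values are illustrative assumptions.

#include <cstdio>
#include <vector>

int main() {
  constexpr int kPadId = 50256;  // assumed shared EOS/pad id, as in GPT-2-style vocabularies
  // One right-padded row; the 50256 at index 1 is a real EOS token, not padding.
  const std::vector<int> tokens = {9, 50256, 42, 50256, 50256};

  // Old behavior: forward scan masks every pad-id token, including the mid-sequence EOS.
  std::vector<int> old_mask;
  for (int t : tokens) old_mask.push_back(t == kPadId ? 0 : 1);

  // Fixed behavior: backward scan masks only the trailing run of pad-id tokens.
  std::vector<int> new_mask(tokens.size());
  bool found_first_non_pad = false;
  for (int j = static_cast<int>(tokens.size()) - 1; j >= 0; --j) {
    if (tokens[j] != kPadId) found_first_non_pad = true;
    new_mask[j] = found_first_non_pad ? 1 : 0;
  }

  for (int m : old_mask) std::printf("%d ", m);  // prints: 1 0 1 0 0  (EOS wrongly masked)
  std::printf("\n");
  for (int m : new_mask) std::printf("%d ", m);  // prints: 1 1 1 0 0
  std::printf("\n");
  return 0;
}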

File tree

2 files changed, +89 -19 lines

src/models/position_inputs.cpp

Lines changed: 50 additions & 19 deletions
@@ -184,21 +184,39 @@ void DefaultPositionInputs::CreateAndInitializePositionIDs(DeviceSpan<int32_t> n
   // Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
   // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
   auto position_ids = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_);
-  auto position_ids_next = OrtValue::CreateTensor(model_.allocator_cpu_, std::array<int64_t, 2>{shape[0], 1}, type_);
   auto* position_data = position_ids->GetTensorMutableData<T>();
+  auto position_ids_next = OrtValue::CreateTensor(model_.allocator_cpu_, std::array<int64_t, 2>{shape[0], 1}, type_);
   auto* position_data_next = position_ids_next->GetTensorMutableData<T>();
-  const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data();
-  auto* position = position_data;
-  for (int i = 0; i < shape[0]; i++) {
-    T abs_position = 0;
-    for (int j = 0; j < shape[1]; j++, word_id++, position++) {
-      if (*word_id == model_.config_->model.pad_token_id) {
-        *position = 0;
-      } else {
-        *position = abs_position++;
+  // If batch_size is 1 we have no padding, so positions are a simple ascending sequence
+  if (shape[0] == 1) {
+    for (int i = 0; i < shape[1]; ++i) {
+      position_data[i] = static_cast<T>(i);
+    }
+    position_data_next[0] = static_cast<T>(shape[1]) - 1;
+    // Otherwise we iterate backwards so as not to misinterpret any right-pad tokens
+  } else {
+    const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data() + shape[0] * shape[1] - 1;
+    auto* position = position_data + shape[0] * shape[1] - 1;
+    bool found_first_non_pad = false;
+    for (int i = static_cast<int>(shape[0] - 1); i >= 0; i--) {
+      T abs_position = static_cast<T>(shape[1] - 1);
+      found_first_non_pad = false;
+      for (int j = static_cast<int>(shape[1] - 1); j >= 0; j--, word_id--, position--) {
+        // Once the first non-pad token has been seen, every earlier token keeps its real position
+        if (found_first_non_pad) {
+          *position = abs_position;
+          // First non-pad token from the end: record its position and seed position_data_next for this row
+        } else if (*word_id != model_.config_->model.pad_token_id) {
+          found_first_non_pad = true;
+          *position = abs_position;
+          position_data_next[i] = abs_position;
+          // Still inside the trailing padding, so the position is 0
+        } else {
+          *position = 0;
+        }
+        abs_position--;
       }
     }
-    position_data_next[i] = abs_position - 1;
   }
 
   // Move tensors to appropriate device and expand by num_beams
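
Worked through on a toy row, the new position-id loop behaves as follows. This is a hedged standalone sketch that mirrors the loop above, not the library code; the pad id and token values are illustrative. The old forward scan would have produced positions 0 0 1 0 0 with a next position of 1 for this row, because the mid-sequence EOS was treated as padding.

#include <cstdio>
#include <vector>

int main() {
  constexpr int kPadId = 50256;  // assumed shared EOS/pad id
  // Real EOS at index 1; indices 3..4 are right padding.
  const std::vector<int> row = {9, 50256, 42, 50256, 50256};
  std::vector<int> position(row.size());
  int next_position = 0;

  // Mirrors the fixed loop: walk right-to-left, zeroing only the trailing pad run.
  bool found_first_non_pad = false;
  int abs_position = static_cast<int>(row.size()) - 1;
  for (int j = abs_position; j >= 0; --j, --abs_position) {
    if (found_first_non_pad) {
      position[j] = abs_position;    // real token (or mid-sequence EOS): real position
    } else if (row[j] != kPadId) {
      found_first_non_pad = true;
      position[j] = abs_position;
      next_position = abs_position;  // position of the last real token in this row
    } else {
      position[j] = 0;               // trailing padding
    }
  }

  for (int p : position) std::printf("%d ", p);  // prints: 0 1 2 0 0
  std::printf("\nnext: %d\n", next_position);    // prints: next: 2
  return 0;
}
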
@@ -247,14 +265,27 @@ void DefaultPositionInputs::CreateAndInitializeAttentionMask(DeviceSpan<int32_t>
   // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
   auto attention_mask = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_);
   auto* mask_data = attention_mask->GetTensorMutableData<T>();
-  const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data();
-  auto* mask = mask_data;
-  for (int i = 0; i < shape[0]; i++) {
-    for (int j = 0; j < shape[1]; j++, word_id++, mask++) {
-      if (*word_id == model_.config_->model.pad_token_id) {
-        *mask = 0;
-      } else {
-        *mask = 1;
+  // If batch size is 1, we have no padding, so we simply set every mask entry to 1
+  if (shape[0] == 1) {
+    for (int i = 0; i < shape[1]; ++i) {
+      mask_data[i] = 1;
+    }
+    // Otherwise we iterate backwards so as not to misinterpret any right-pad tokens
+  } else {
+    auto* mask = mask_data + shape[0] * shape[1] - 1;
+    const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data() + shape[0] * shape[1] - 1;
+    bool found_first_non_pad = false;
+    for (int i = static_cast<int>(shape[0] - 1); i >= 0; i--) {
+      found_first_non_pad = false;
+      for (int j = static_cast<int>(shape[1] - 1); j >= 0; j--, word_id--, mask--) {
+        if (found_first_non_pad) {
+          *mask = 1;
+        } else if (*word_id != model_.config_->model.pad_token_id) {
+          found_first_non_pad = true;
+          *mask = 1;
+        } else {
+          *mask = 0;
+        }
       }
     }
   }
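
The attention-mask hunk uses the same backward walk, but over the flattened [batch_size, sequence_length] buffer, resetting found_first_non_pad at every row boundary. A hedged standalone sketch of that pointer walk follows; the batch contents and the pad id 50256 are illustrative assumptions, not the library code.

#include <cstdio>
#include <vector>

int main() {
  constexpr int kPadId = 50256;        // assumed shared EOS/pad id
  constexpr int kBatch = 2, kSeq = 4;  // stand-ins for shape[0], shape[1]
  // Flattened [batch, seq] buffer as the library reads it: row 0 has a real
  // mid-sequence EOS, row 1 is right-padded by two tokens.
  const std::vector<int> tokens = {9, 50256, 42, 7,
                                   5, 11, kPadId, kPadId};
  std::vector<int> mask(tokens.size());

  // Same walk as the fix: start at the last element, move backwards, and
  // reset found_first_non_pad at each row boundary.
  const int* word_id = tokens.data() + kBatch * kSeq - 1;
  int* m = mask.data() + kBatch * kSeq - 1;
  for (int i = kBatch - 1; i >= 0; --i) {
    bool found_first_non_pad = false;
    for (int j = kSeq - 1; j >= 0; --j, --word_id, --m) {
      if (!found_first_non_pad && *word_id != kPadId) found_first_non_pad = true;
      *m = found_first_non_pad ? 1 : 0;
    }
  }

  for (int i = 0; i < kBatch; ++i) {
    for (int j = 0; j < kSeq; ++j) std::printf("%d ", mask[i * kSeq + j]);
    std::printf("\n");  // prints row 0: 1 1 1 1   row 1: 1 1 0 0
  }
  return 0;
}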

test/c_api_tests.cpp

Lines changed: 39 additions & 0 deletions
@@ -461,6 +461,45 @@ TEST(CAPITests, EndToEndPhi) {
 #endif
 }
 
+TEST(CAPITests, EndToEndPhiEOSPAD) {
+#if TEST_PHI2
+  auto model = OgaModel::Create(PHI2_PATH);
+  auto tokenizer = OgaTokenizer::Create(*model);
+
+  const char* input_string = "This is a test.<|endoftext|>";
+  auto input_sequence = OgaSequences::Create();
+  tokenizer->Encode(input_string, *input_sequence);
+
+  auto params = OgaGeneratorParams::Create(*model);
+  params->SetSearchOption("max_length", 40);
+
+  auto generator = OgaGenerator::Create(*model, *params);
+  generator->AppendTokenSequences(*input_sequence);
+
+  while (!generator->IsDone()) {
+    generator->GenerateNextToken();
+  }
+
+  // Decode the batch
+  auto out_string = tokenizer->Decode(generator->GetSequenceData(0), generator->GetSequenceCount(0));
+  std::cout << "Decoded string:" << out_string << std::endl;
+
+  // Verify outputs match expected outputs
+  std::vector<int32_t> expected_output{
+      1212, 318, 257, 1332, 13, 50256, 198, 198, 198, 198, 4010, 4420, 43168, 15666,
+      10503, 82, 26268, 11451, 12735, 82, 19445, 427, 278, 49292, 3087, 26762, 5101,
+      14453, 5421, 278, 829, 319, 8378, 8378, 10257, 82, 1028, 1028, 16219, 263};
+
+  const auto sequence_length = generator->GetSequenceCount(0);
+  const auto* sequence_data = generator->GetSequenceData(0);
+
+  ASSERT_LE(sequence_length, 40);
+
+  const auto* expected_output_start = &expected_output[0];
+  EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t)));
+#endif
+}
+
 #if ENABLE_ENGINE_TESTS
 TEST(CAPIEngineTests, EndToEndPhi) {
   auto model = OgaModel::Create(PHI2_PATH);
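
For context on the expected tokens: the first six values (1212, 318, 257, 1332, 13, 50256) are the encoded prompt itself, where 50256 is the <|endoftext|> id in the GPT-2-style vocabulary Phi-2 uses, so the assertion also verifies that the prompt's EOS token survives the attention-mask fix instead of being dropped as padding. A hedged fragment to confirm that prefix, assuming OgaSequences exposes SequenceCount/SequenceData as used elsewhere in this test file and reusing the tokenizer from the test above:

// Sketch: encode the prompt alone and print its ids; they should match the
// prefix of expected_output in the test.
auto prompt_tokens = OgaSequences::Create();
tokenizer->Encode("This is a test.<|endoftext|>", *prompt_tokens);
for (size_t i = 0; i < prompt_tokens->SequenceCount(0); ++i)
  std::cout << prompt_tokens->SequenceData(0)[i] << ' ';  // 1212 318 257 1332 13 50256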
