@@ -184,21 +184,39 @@ void DefaultPositionInputs::CreateAndInitializePositionIDs(DeviceSpan<int32_t> n
  // Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
  // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
  auto position_ids = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_);
-  auto position_ids_next = OrtValue::CreateTensor(model_.allocator_cpu_, std::array<int64_t, 2>{shape[0], 1}, type_);
  auto* position_data = position_ids->GetTensorMutableData<T>();
+  auto position_ids_next = OrtValue::CreateTensor(model_.allocator_cpu_, std::array<int64_t, 2>{shape[0], 1}, type_);
  auto* position_data_next = position_ids_next->GetTensorMutableData<T>();
-  const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data();
-  auto* position = position_data;
-  for (int i = 0; i < shape[0]; i++) {
-    T abs_position = 0;
-    for (int j = 0; j < shape[1]; j++, word_id++, position++) {
-      if (*word_id == model_.config_->model.pad_token_id) {
-        *position = 0;
-      } else {
-        *position = abs_position++;
+  // If batch_size is 1 we have no padding, so positions are a simple ascending sequence
+  if (shape[0] == 1) {
+    for (int i = 0; i < shape[1]; ++i) {
+      position_data[i] = static_cast<T>(i);
+    }
+    position_data_next[0] = static_cast<T>(shape[1]) - 1;
+    // Otherwise we iterate backwards so as not to misinterpret any right pad tokens
+  } else {
+    const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data() + shape[0] * shape[1] - 1;
+    auto* position = position_data + shape[0] * shape[1] - 1;
+    bool found_first_non_pad = false;
+    for (int i = static_cast<int>(shape[0] - 1); i >= 0; i--) {
+      T abs_position = static_cast<T>(shape[1] - 1);
+      found_first_non_pad = false;
+      for (int j = static_cast<int>(shape[1] - 1); j >= 0; j--, word_id--, position--) {
+        // Once the first non-pad token has been found, every token to its left gets its absolute position
+        if (found_first_non_pad) {
+          *position = abs_position;
+        // This is the rightmost non-pad token in the row: record it, and use it as the next position id
+        } else if (*word_id != model_.config_->model.pad_token_id) {
+          found_first_non_pad = true;
+          *position = abs_position;
+          position_data_next[i] = abs_position;
+        // No non-pad token found yet, so this is a right pad token and its position is 0
+        } else {
+          *position = 0;
+        }
+        abs_position--;
      }
    }
-    position_data_next[i] = abs_position - 1;
  }

  // Move tensors to appropriate device and expand by num_beams
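
To make the new right-padding behavior concrete, here is a minimal standalone sketch (not part of the diff) that mirrors the backward position fill on a hypothetical 2x4 batch; the pad_id value, the token values, and the use of int32_t for the position type are illustrative assumptions, not values taken from the PR.

// Minimal sketch (not part of the diff): backward position fill for a
// hypothetical right-padded 2x4 batch. pad_id and the token values are made up.
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  constexpr int batch = 2;
  constexpr int seq = 4;
  constexpr int32_t pad_id = 0;
  // Second row is right-padded with two pad tokens.
  const std::array<int32_t, batch * seq> tokens = {5, 6, 7, 8, 9, 10, pad_id, pad_id};
  std::array<int32_t, batch * seq> positions{};
  std::array<int32_t, batch> positions_next{};

  const int32_t* word_id = tokens.data() + batch * seq - 1;
  int32_t* position = positions.data() + batch * seq - 1;
  for (int i = batch - 1; i >= 0; i--) {
    int32_t abs_position = seq - 1;
    bool found_first_non_pad = false;
    for (int j = seq - 1; j >= 0; j--, word_id--, position--) {
      if (found_first_non_pad) {
        *position = abs_position;          // real token left of the last non-pad token
      } else if (*word_id != pad_id) {
        found_first_non_pad = true;        // rightmost non-pad token of this row
        *position = abs_position;
        positions_next[i] = abs_position;  // position id for the next generated token
      } else {
        *position = 0;                     // right pad token
      }
      abs_position--;
    }
  }

  // Prints: 0 1 2 3 | next = 3
  //         0 1 0 0 | next = 1
  for (int i = 0; i < batch; i++) {
    for (int j = 0; j < seq; j++) std::cout << positions[i * seq + j] << ' ';
    std::cout << "| next = " << positions_next[i] << '\n';
  }
  return 0;
}
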
@@ -247,14 +265,27 @@ void DefaultPositionInputs::CreateAndInitializeAttentionMask(DeviceSpan<int32_t>
  // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
  auto attention_mask = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_);
  auto* mask_data = attention_mask->GetTensorMutableData<T>();
-  const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data();
-  auto* mask = mask_data;
-  for (int i = 0; i < shape[0]; i++) {
-    for (int j = 0; j < shape[1]; j++, word_id++, mask++) {
-      if (*word_id == model_.config_->model.pad_token_id) {
-        *mask = 0;
-      } else {
-        *mask = 1;
+  // If batch size is 1, we have no padding, so we simply set the whole mask to 1
+  if (shape[0] == 1) {
+    for (int i = 0; i < shape[1]; ++i) {
+      mask_data[i] = 1;
+    }
+    // Otherwise we iterate backwards so as not to misinterpret any right pad tokens
+  } else {
+    auto* mask = mask_data + shape[0] * shape[1] - 1;
+    const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data() + shape[0] * shape[1] - 1;
+    bool found_first_non_pad = false;
+    for (int i = static_cast<int>(shape[0] - 1); i >= 0; i--) {
+      found_first_non_pad = false;
+      for (int j = static_cast<int>(shape[1] - 1); j >= 0; j--, word_id--, mask--) {
+        if (found_first_non_pad) {
+          *mask = 1;
+        } else if (*word_id != model_.config_->model.pad_token_id) {
+          found_first_non_pad = true;
+          *mask = 1;
+        } else {
+          *mask = 0;
+        }
      }
    }
  }
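
The attention mask hunk follows the same pattern. Below is a minimal sketch (again not part of the diff) of the backward mask fill for the same hypothetical right-padded batch; pad_id and the token values are illustrative assumptions.

// Minimal sketch (not part of the diff): backward attention mask fill for the
// same hypothetical right-padded 2x4 batch. pad_id and token values are made up.
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  constexpr int batch = 2;
  constexpr int seq = 4;
  constexpr int32_t pad_id = 0;
  const std::array<int32_t, batch * seq> tokens = {5, 6, 7, 8, 9, 10, pad_id, pad_id};
  std::array<int32_t, batch * seq> mask_data{};

  int32_t* mask = mask_data.data() + batch * seq - 1;
  const int32_t* word_id = tokens.data() + batch * seq - 1;
  for (int i = batch - 1; i >= 0; i--) {
    bool found_first_non_pad = false;
    for (int j = seq - 1; j >= 0; j--, word_id--, mask--) {
      if (found_first_non_pad) {
        *mask = 1;  // once a real token is seen from the right, everything to its left is attended to
      } else if (*word_id != pad_id) {
        found_first_non_pad = true;
        *mask = 1;
      } else {
        *mask = 0;  // trailing (right) pad tokens are masked out
      }
    }
  }

  // Prints: 1 1 1 1
  //         1 1 0 0
  for (int i = 0; i < batch; i++) {
    for (int j = 0; j < seq; j++) std::cout << mask_data[i * seq + j] << ' ';
    std::cout << '\n';
  }
  return 0;
}

Scanning each row from the right means the code never needs to know in advance where padding begins: pad tokens encountered before the first real token are treated as right padding and zeroed, and everything to the left of that token is kept.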