|
13 | 13 | from typing_extensions import override |
14 | 14 |
|
15 | 15 | from fairseq2.device import Device, SupportsDeviceTransfer |
| 16 | +from fairseq2.error import InvalidOperationError |
16 | 17 | from fairseq2.nn import BatchLayout |
17 | 18 |
|
18 | 19 |
|
@@ -105,7 +106,7 @@ def __init__( |
105 | 106 | self._padding = 0 |
106 | 107 |
|
107 | 108 | for idx, seq_len in enumerate(seq_lens): |
108 | | - if seq_len < 0: |
| 109 | + if seq_len < 1: |
109 | 110 | raise ValueError( |
110 | 111 | f"All lengths in `seq_lens` must be greater than or equal to 1, but the length at index {idx} is {seq_len} instead." |
111 | 112 | ) |
@@ -163,11 +164,27 @@ def as_auto_regressive(self) -> tuple[SequenceBatch, SequenceBatch]: |
163 | 164 |
|
164 | 165 | seq_lens = self._seq_lens.copy() |
165 | 166 |
|
166 | | - seq_lens[-1] -= 1 |
| 167 | + if seq_lens[-1] == 1: |
| 168 | + if len(seq_lens) == 1: |
| 169 | + raise InvalidOperationError( |
| 170 | + "The length of the sequence at index 0 is already 1 and cannot be trimmed to 0." |
| 171 | + ) |
| 172 | + |
| 173 | + del seq_lens[-1] |
| 174 | + else: |
| 175 | + seq_lens[-1] -= 1 |
167 | 176 | else: |
168 | 177 | seqs = self._seqs[:, :-1] |
169 | 178 |
|
170 | | - seq_lens = [seq_len - 1 for seq_len in self._seq_lens] |
| 179 | + seq_lens = [] |
| 180 | + |
| 181 | + for idx, seq_len in enumerate(self._seq_lens): |
| 182 | + if seq_len == 1: |
| 183 | + raise InvalidOperationError( |
| 184 | + f"The length of the sequence at index {idx} is already 1 and cannot be trimmed to 0." |
| 185 | + ) |
| 186 | + |
| 187 | + seq_lens.append(seq_len - 1) |
171 | 188 |
|
172 | 189 | batch = SequenceBatch( |
173 | 190 | seqs, seq_lens, packed=self._packed, example=self._example |
@@ -487,11 +504,27 @@ def as_auto_regressive(self) -> tuple[Seq2SeqBatch, SequenceBatch]: |
487 | 504 |
|
488 | 505 | seq_lens = self._target_seq_lens.copy() |
489 | 506 |
|
490 | | - seq_lens[-1] -= 1 |
| 507 | + if seq_lens[-1] == 1: |
| 508 | + if len(seq_lens) == 1: |
| 509 | + raise InvalidOperationError( |
| 510 | + "The length of the target sequence at index 0 is already 1 and cannot be trimmed to 0." |
| 511 | + ) |
| 512 | + |
| 513 | + del seq_lens[-1] |
| 514 | + else: |
| 515 | + seq_lens[-1] -= 1 |
491 | 516 | else: |
492 | 517 | seqs = self._target_seqs[:, :-1] |
493 | 518 |
|
494 | | - seq_lens = [seq_len - 1 for seq_len in self._target_seq_lens] |
| 519 | + seq_lens = [] |
| 520 | + |
| 521 | + for idx, seq_len in enumerate(self._target_seq_lens): |
| 522 | + if seq_len == 1: |
| 523 | + raise InvalidOperationError( |
| 524 | + f"The length of the target sequence at index {idx} is already 1 and cannot be trimmed to 0." |
| 525 | + ) |
| 526 | + |
| 527 | + seq_lens.append(seq_len - 1) |
495 | 528 |
|
496 | 529 | batch = Seq2SeqBatch( |
497 | 530 | self._source_seqs, |
|
0 commit comments