Skip to content

Commit 76c6377

Browse files
committed
Allow negative indices for slicing
1 parent 88be559 commit 76c6377

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

src/models.js

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ import {
6464

6565
import {
6666
Tensor,
67-
cat
6867
} from './utils/tensor.js';
6968

7069
import { executionProviders, ONNX } from './backends/onnx.js';
@@ -839,16 +838,9 @@ export class PreTrainedModel extends Callable {
839838
// In most cases, this will be [batch_size, 1, vocab_size]
840839
// So, we select the last token's logits:
841840
// (equivalent to `logits = outputs.logits[:, -1, :]`)
842-
let extractedLogits = [];
843-
for (const batch of output.logits) {
844-
// Extract logits corresponding to the last token
845-
let lastLogits = batch[-1];
846-
847-
// Add back batch dimension (needed for `cat`)
848-
lastLogits.dims = [1, ...lastLogits.dims];
849-
extractedLogits.push(lastLogits)
850-
}
851-
let logits = cat(extractedLogits);
841+
let logits = output.logits.slice(null, -1, null);
842+
843+
// Apply logits processor
852844
logits_processor(beam.output_token_ids, logits);
853845

854846
let sampledTokens = sampler(logits);

src/utils/tensor.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,12 @@ export class Tensor extends ONNXTensor {
185185
newTensorDims.push(this.dims[sliceIndex]);
186186

187187
} else if (typeof slice === 'number') {
188-
if (slice < 0 || slice >= this.dims[sliceIndex]) {
188+
if (slice < -this.dims[sliceIndex] || slice >= this.dims[sliceIndex]) {
189189
throw new Error(`IndexError: index ${slice} is out of bounds for dimension ${sliceIndex} with size ${this.dims[sliceIndex]}`);
190190
}
191+
if (slice < 0) {
192+
slice += this.dims[sliceIndex];
193+
}
191194

192195
// A number means take a single element
193196
newOffsets.push([slice, slice + 1]);

0 commit comments

Comments
 (0)