File tree Expand file tree Collapse file tree 2 files changed +7
-12
lines changed Expand file tree Collapse file tree 2 files changed +7
-12
lines changed Original file line number Diff line number Diff line change @@ -64,7 +64,6 @@ import {
6464
6565import {
6666 Tensor ,
67- cat
6867} from './utils/tensor.js' ;
6968
7069import { executionProviders , ONNX } from './backends/onnx.js' ;
@@ -839,16 +838,9 @@ export class PreTrainedModel extends Callable {
839838 // In most cases, this will be [batch_size, 1, vocab_size]
840839 // So, we select the last token's logits:
841840 // (equivalent to `logits = outputs.logits[:, -1, :]`)
842- let extractedLogits = [ ] ;
843- for ( const batch of output . logits ) {
844- // Extract logits corresponding to the last token
845- let lastLogits = batch [ - 1 ] ;
846-
847- // Add back batch dimension (needed for `cat`)
848- lastLogits . dims = [ 1 , ...lastLogits . dims ] ;
849- extractedLogits . push ( lastLogits )
850- }
851- let logits = cat ( extractedLogits ) ;
841+ let logits = output . logits . slice ( null , - 1 , null ) ;
842+
843+ // Apply logits processor
852844 logits_processor ( beam . output_token_ids , logits ) ;
853845
854846 let sampledTokens = sampler ( logits ) ;
Original file line number Diff line number Diff line change @@ -185,9 +185,12 @@ export class Tensor extends ONNXTensor {
185185 newTensorDims . push ( this . dims [ sliceIndex ] ) ;
186186
187187 } else if ( typeof slice === 'number' ) {
188- if ( slice < 0 || slice >= this . dims [ sliceIndex ] ) {
188+ if ( slice < - this . dims [ sliceIndex ] || slice >= this . dims [ sliceIndex ] ) {
189189 throw new Error ( `IndexError: index ${ slice } is out of bounds for dimension ${ sliceIndex } with size ${ this . dims [ sliceIndex ] } ` ) ;
190190 }
191+ if ( slice < 0 ) {
192+ slice += this . dims [ sliceIndex ] ;
193+ }
191194
192195 // A number means take a single element
193196 newOffsets . push ( [ slice , slice + 1 ] ) ;
You can’t perform that action at this time.
0 commit comments