Skip to content

Commit e48d6eb

Browse files
committed
Fix return types of logits processors
1 parent c38627e commit e48d6eb

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/generation/logits_process.js

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
151151
* Apply the BOS token forcing to the logits.
152152
* @param {bigint[][]} input_ids The input IDs.
153153
* @param {Tensor} logits The logits.
154-
* @returns {Object} The logits with BOS token forcing.
154+
* @returns {Tensor} The logits with BOS token forcing.
155155
*/
156156
_call(input_ids, logits) {
157157
for (let i = 0; i < input_ids.length; ++i) {
@@ -221,7 +221,7 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
221221
* Apply the BOS token forcing to the logits.
222222
* @param {bigint[][]} input_ids The input IDs.
223223
* @param {Tensor} logits The logits.
224-
* @returns {Object} The logits with BOS token forcing.
224+
* @returns {Tensor} The logits with BOS token forcing.
225225
*/
226226
_call(input_ids, logits) {
227227
for (let i = 0; i < input_ids.length; ++i) {
@@ -391,7 +391,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
391391
* Apply the no-repeat-ngram processor to the logits.
392392
* @param {bigint[][]} input_ids The input IDs.
393393
* @param {Tensor} logits The logits.
394-
* @returns {Object} The logits with no-repeat-ngram processing.
394+
* @returns {Tensor} The logits with no-repeat-ngram processing.
395395
*/
396396
_call(input_ids, logits) {
397397
for (let i = 0; i < input_ids.length; ++i) {
@@ -432,7 +432,7 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
432432
* Apply the repetition penalty to the logits.
433433
* @param {bigint[][]} input_ids The input IDs.
434434
* @param {Tensor} logits The logits.
435-
* @returns {Object} The logits with repetition penalty processing.
435+
* @returns {Tensor} The logits with repetition penalty processing.
436436
*/
437437
_call(input_ids, logits) {
438438
for (let i = 0; i < input_ids.length; ++i) {
@@ -470,7 +470,7 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
470470
* Apply logit processor.
471471
* @param {bigint[][]} input_ids The input IDs.
472472
* @param {Tensor} logits The logits.
473-
* @returns {Object} The processed logits.
473+
* @returns {Tensor} The processed logits.
474474
*/
475475
_call(input_ids, logits) {
476476
for (let i = 0; i < input_ids.length; ++i) {
@@ -508,7 +508,7 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
508508
* Apply logit processor.
509509
* @param {bigint[][]} input_ids The input IDs.
510510
* @param {Tensor} logits The logits.
511-
* @returns {Object} The processed logits.
511+
* @returns {Tensor} The processed logits.
512512
*/
513513
_call(input_ids, logits) {
514514
for (let i = 0; i < input_ids.length; ++i) {
@@ -541,7 +541,7 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
541541
* Apply logit processor.
542542
* @param {bigint[][]} input_ids The input IDs.
543543
* @param {Tensor} logits The logits.
544-
* @returns {Object} The processed logits.
544+
* @returns {Tensor} The processed logits.
545545
*/
546546
_call(input_ids, logits) {
547547
for (let i = 0; i < input_ids.length; ++i) {
@@ -602,7 +602,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {
602602
* Apply logit processor.
603603
* @param {bigint[][]} input_ids The input IDs.
604604
* @param {Tensor} logits The logits.
605-
* @returns {Object} The processed logits.
605+
* @returns {Tensor} The processed logits.
606606
*/
607607
_call(input_ids, logits) {
608608
if (logits.dims[0] !== 2 * input_ids.length) {
@@ -656,7 +656,7 @@ export class TemperatureLogitsWarper extends LogitsWarper {
656656
* Apply logit warper.
657657
* @param {bigint[][]} input_ids The input IDs.
658658
* @param {Tensor} logits The logits.
659-
* @returns {Object} The processed logits.
659+
* @returns {Tensor} The processed logits.
660660
*/
661661
_call(input_ids, logits) {
662662
const batch_logits_data = /** @type {Float32Array} */(logits.data);

0 commit comments

Comments
 (0)