Skip to content

Commit f25b17c

Browse files
committed
Merge branch 'es6-rewrite' into docs
2 parents a0b53f6 + 421aa33 commit f25b17c

File tree

16 files changed

+167
-280
lines changed

16 files changed

+167
-280
lines changed

jsconfig.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
],
66
"compilerOptions": {
77
// Tells the compiler to check JS files
8-
"checkJs": true
8+
"checkJs": true,
9+
"target": "esnext",
10+
"module": "esnext",
11+
"moduleResolution": "nodenext",
912
}
1013
}

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"description": "Run 🤗 Transformers in your browser! We currently support BERT, ALBERT, DistilBERT, MobileBERT, SqueezeBERT, T5, T5v1.1, FLAN-T5, mT5, BART, MarianMT, GPT2, GPT Neo, CodeGen, Whisper, CLIP, Vision Transformer, VisionEncoderDecoder, and DETR models, for a variety of tasks including: masked language modelling, text classification, token classification, zero-shot classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, image segmentation, and object detection.",
55
"main": "./src/transformers.js",
66
"types": "./types/transformers.d.ts",
7+
"type": "module",
78
"directories": {
89
"test": "tests"
910
},

src/backends/onnx.js

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
let ONNX;
1+
export let ONNX;
22

33
// TODO support more execution providers (e.g., webgpu)
4-
const executionProviders = ['wasm'];
4+
export const executionProviders = ['wasm'];
55

66
if (typeof process !== 'undefined') {
77
// Running in a node-like environment.
88
// Try to import onnxruntime-node, using onnxruntime-web as a fallback
99
try {
10-
ONNX = require('onnxruntime-node');
10+
ONNX = (await import('onnxruntime-node')).default;
1111
} catch (err) {
1212
console.warn(
1313
"Node.js environment detected, but `onnxruntime-node` was not found. " +
@@ -20,7 +20,7 @@ if (typeof process !== 'undefined') {
2020
// @ts-ignore
2121
global.self = global;
2222

23-
ONNX = require('onnxruntime-web');
23+
ONNX = (await import('onnxruntime-web')).default;
2424

2525
// Disable spawning worker threads for testing.
2626
// This is done by setting numThreads to 1
@@ -33,10 +33,6 @@ if (typeof process !== 'undefined') {
3333

3434
} else {
3535
// Running in a browser-environment, so we just import `onnxruntime-web`
36-
ONNX = require('onnxruntime-web');
36+
ONNX = (await import('onnxruntime-web')).default;
3737
}
3838

39-
module.exports = {
40-
ONNX,
41-
executionProviders,
42-
}

src/env.js

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
const fs = require('fs');
2-
const path = require('path');
1+
import fs from 'fs';
2+
import path from 'path';
3+
import { fileURLToPath } from 'url';
34

4-
const { env: onnx_env } = require('./backends/onnx.js').ONNX;
5+
import { ONNX } from './backends/onnx.js';
6+
const { env: onnx_env } = ONNX;
7+
8+
const __dirname = path.dirname(path.dirname(fileURLToPath(import.meta.url)));
59

610

711

@@ -28,7 +32,7 @@ const localModelPath = RUNNING_LOCALLY
2832
// We use remote wasm files by default to make it easier for newer users.
2933
// In practice, users should probably self-host the necessary .wasm files.
3034
onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
31-
? path.join(path.dirname(__dirname), '/dist/')
35+
? path.join(__dirname, '/dist/')
3236
: 'https://cdn.jsdelivr.net/npm/@xenova/transformers/dist/';
3337

3438

@@ -55,7 +59,9 @@ const env = {
5559

5660
// Whether to use the file system to load files. By default, it is true available.
5761
useFS: FS_AVAILABLE,
58-
62+
63+
// Directory name of module. Useful for resolving local paths.
64+
__dirname,
5965

6066
/////////////////// Cache settings ///////////////////
6167
// Whether to use Cache API to cache models. By default, it is true if available.
@@ -78,6 +84,3 @@ function isEmpty(obj) {
7884
return Object.keys(obj).length === 0;
7985
}
8086

81-
module.exports = {
82-
env
83-
}

src/generation.js

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
const { Tensor } = require("./tensor_utils.js");
2-
const {
1+
import { Tensor } from './tensor_utils.js';
2+
import {
33
Callable,
44
exists,
5-
} = require("./utils.js");
6-
const {
5+
} from './utils.js';
6+
import {
77
log_softmax
8-
} = require('./math_utils.js');
8+
} from './math_utils.js';
99

1010
/**
1111
* A class representing a list of logits processors. A logits processor is a function that modifies the logits
@@ -14,7 +14,7 @@ const {
1414
*
1515
* @extends Callable
1616
*/
17-
class LogitsProcessorList extends Callable {
17+
export class LogitsProcessorList extends Callable {
1818
/**
1919
* Constructs a new instance of `LogitsProcessorList`.
2020
*/
@@ -68,7 +68,7 @@ class LogitsProcessorList extends Callable {
6868
* Base class for processing logits.
6969
* @extends Callable
7070
*/
71-
class LogitsProcessor extends Callable {
71+
export class LogitsProcessor extends Callable {
7272
/**
7373
* Apply the processor to the input logits.
7474
*
@@ -87,7 +87,7 @@ class LogitsProcessor extends Callable {
8787
*
8888
* @extends LogitsProcessor
8989
*/
90-
class ForceTokensLogitsProcessor extends LogitsProcessor {
90+
export class ForceTokensLogitsProcessor extends LogitsProcessor {
9191
/**
9292
* Constructs a new instance of `ForceTokensLogitsProcessor`.
9393
*
@@ -119,7 +119,7 @@ class ForceTokensLogitsProcessor extends LogitsProcessor {
119119
* A LogitsProcessor that forces a BOS token at the beginning of the generated sequence.
120120
* @extends LogitsProcessor
121121
*/
122-
class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
122+
export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
123123
/**
124124
* Create a ForcedBOSTokenLogitsProcessor.
125125
* @param {number} bos_token_id - The ID of the beginning-of-sequence token to be forced.
@@ -148,7 +148,7 @@ class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
148148
*
149149
* @extends LogitsProcessor
150150
*/
151-
class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
151+
export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
152152
/**
153153
* Create a ForcedEOSTokenLogitsProcessor.
154154
* @param {number} max_length - Max length of the sequence.
@@ -176,7 +176,7 @@ class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
176176
* A LogitsProcessor that handles adding timestamps to generated text.
177177
* @extends LogitsProcessor
178178
*/
179-
class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
179+
export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
180180
/**
181181
* Constructs a new WhisperTimeStampLogitsProcessor.
182182
* @param {object} generate_config - The config object passed to the `generate()` method of a transformer model.
@@ -251,7 +251,7 @@ class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
251251
*
252252
* @extends LogitsProcessor
253253
*/
254-
class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
254+
export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
255255
/**
256256
* Create a NoRepeatNGramLogitsProcessor.
257257
* @param {number} no_repeat_ngram_size - The no-repeat-ngram size. All ngrams of this size can only occur once.
@@ -342,7 +342,7 @@ class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
342342
*
343343
* @extends LogitsProcessor
344344
*/
345-
class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
345+
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
346346
/**
347347
* Create a RepetitionPenaltyLogitsProcessor.
348348
* @param {number} penalty - The penalty to apply for repeated tokens.
@@ -374,7 +374,7 @@ class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
374374
}
375375

376376

377-
class GenerationConfig {
377+
export class GenerationConfig {
378378
/**
379379
* Create a GenerationConfig object
380380
* @constructor
@@ -495,15 +495,3 @@ class GenerationConfig {
495495
this.generation_kwargs = kwargs.generation_kwargs ?? {};
496496
}
497497
}
498-
499-
module.exports = {
500-
LogitsProcessor,
501-
LogitsProcessorList,
502-
GenerationConfig,
503-
ForcedBOSTokenLogitsProcessor,
504-
ForcedEOSTokenLogitsProcessor,
505-
WhisperTimeStampLogitsProcessor,
506-
ForceTokensLogitsProcessor,
507-
NoRepeatNGramLogitsProcessor,
508-
RepetitionPenaltyLogitsProcessor
509-
};

src/image_utils.js

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11

2-
const fs = require('fs');
3-
const { isString } = require('./utils.js');
4-
const { env } = require('./env.js');
2+
import fs from 'fs';
3+
import { isString } from './utils.js';
4+
import { env } from './env.js';
55

6-
const { getFile } = require('./utils/hub.js');
6+
import { getFile } from './utils/hub.js';
77

88
// Will be empty (or not used) if running in browser or web-worker
9-
const sharp = require('sharp');
9+
import sharp from 'sharp';
1010

1111
let CanvasClass;
1212
let ImageDataClass;
@@ -30,7 +30,7 @@ if (typeof self !== 'undefined') {
3030
}
3131

3232

33-
class CustomImage {
33+
export class CustomImage {
3434

3535
/**
3636
* Create a new CustomImage object.
@@ -440,7 +440,3 @@ class CustomImage {
440440
fs.writeFileSync(path, buffer);
441441
}
442442
}
443-
444-
module.exports = {
445-
CustomImage,
446-
};

src/math_utils.js

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
/**
99
* @param {TypedArray} input
1010
*/
11-
function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
11+
export function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
1212
// TODO use mode and align_corners
1313

1414
// Output image dimensions
@@ -86,7 +86,7 @@ function interpolate(input, [in_channels, in_height, in_width], [out_height, out
8686
* @param {number[]} axes
8787
* @returns {[T, number[]]} The transposed array and the new shape.
8888
*/
89-
function transpose_data(array, dims, axes) {
89+
export function transpose_data(array, dims, axes) {
9090
// Calculate the new shape of the transposed array
9191
// and the stride of the original array
9292
const shape = new Array(axes.length);
@@ -125,7 +125,7 @@ function transpose_data(array, dims, axes) {
125125
* @param {number[]} arr - The array of numbers to compute the softmax of.
126126
* @returns {number[]} The softmax array.
127127
*/
128-
function softmax(arr) {
128+
export function softmax(arr) {
129129
// Compute the maximum value in the array
130130
const maxVal = max(arr)[0];
131131

@@ -146,7 +146,7 @@ function softmax(arr) {
146146
* @param {number[]} arr - The input array to calculate the log_softmax function for.
147147
* @returns {any} - The resulting log_softmax array.
148148
*/
149-
function log_softmax(arr) {
149+
export function log_softmax(arr) {
150150
// Compute the softmax values
151151
const softmaxArr = softmax(arr);
152152

@@ -162,7 +162,7 @@ function log_softmax(arr) {
162162
* @param {number[]} arr2 - The second array.
163163
* @returns {number} - The dot product of arr1 and arr2.
164164
*/
165-
function dot(arr1, arr2) {
165+
export function dot(arr1, arr2) {
166166
return arr1.reduce((acc, val, i) => acc + val * arr2[i], 0);
167167
}
168168

@@ -174,7 +174,7 @@ function dot(arr1, arr2) {
174174
* @param {number} [top_k=0] - The number of top items to return (default: 0 = return all)
175175
* @returns {Array} - The top k items, sorted by descending order
176176
*/
177-
function getTopItems(items, top_k = 0) {
177+
export function getTopItems(items, top_k = 0) {
178178
// if top == 0, return all
179179

180180
items = Array.from(items)
@@ -195,7 +195,7 @@ function getTopItems(items, top_k = 0) {
195195
* @param {number[]} arr2 - The second array.
196196
* @returns {number} The cosine similarity between the two arrays.
197197
*/
198-
function cos_sim(arr1, arr2) {
198+
export function cos_sim(arr1, arr2) {
199199
// Calculate dot product of the two arrays
200200
const dotProduct = dot(arr1, arr2);
201201

@@ -216,7 +216,7 @@ function cos_sim(arr1, arr2) {
216216
* @param {number[]} arr - The array to calculate the magnitude of.
217217
* @returns {number} The magnitude of the array.
218218
*/
219-
function magnitude(arr) {
219+
export function magnitude(arr) {
220220
return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0));
221221
}
222222

@@ -227,7 +227,7 @@ function magnitude(arr) {
227227
* @returns {number[]} - the value and index of the minimum element, of the form: [valueOfMin, indexOfMin]
228228
* @throws {Error} If array is empty.
229229
*/
230-
function min(arr) {
230+
export function min(arr) {
231231
if (arr.length === 0) throw Error('Array must not be empty');
232232
let min = arr[0];
233233
let indexOfMin = 0;
@@ -247,7 +247,7 @@ function min(arr) {
247247
* @returns {number[]} - the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
248248
* @throws {Error} If array is empty.
249249
*/
250-
function max(arr) {
250+
export function max(arr) {
251251
if (arr.length === 0) throw Error('Array must not be empty');
252252
let max = arr[0];
253253
let indexOfMax = 0;
@@ -265,7 +265,7 @@ function max(arr) {
265265
* FFT class provides functionality for performing Fast Fourier Transform on arrays
266266
* Code adapted from https://www.npmjs.com/package/fft.js
267267
*/
268-
class FFT {
268+
export class FFT {
269269
/**
270270
* @param {number} size - The size of the input array. Must be a power of two and bigger than 1.
271271
* @throws {Error} FFT size must be a power of two and bigger than 1.
@@ -757,16 +757,3 @@ class FFT {
757757
}
758758
}
759759

760-
module.exports = {
761-
interpolate,
762-
transpose: transpose_data,
763-
softmax,
764-
log_softmax,
765-
getTopItems,
766-
dot,
767-
cos_sim,
768-
magnitude,
769-
min,
770-
max,
771-
FFT
772-
}

0 commit comments

Comments
 (0)