Skip to content

Commit d79c055

Browse files
committed
Replace CommonJS imports/exports with ES6
1 parent 5a0a40f commit d79c055

File tree

17 files changed

+163
-280
lines changed

17 files changed

+163
-280
lines changed

jsconfig.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
{
22
"compilerOptions": {
3-
"checkJs": true
3+
"checkJs": true,
4+
"target": "esnext",
5+
"module": "esnext",
6+
"moduleResolution": "nodenext",
47
}
58
}

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"version": "1.4.2",
44
"description": "Run 🤗 Transformers in your browser! We currently support BERT, ALBERT, DistilBERT, MobileBERT, SqueezeBERT, T5, T5v1.1, FLAN-T5, mT5, BART, MarianMT, GPT2, GPT Neo, CodeGen, Whisper, CLIP, Vision Transformer, VisionEncoderDecoder, and DETR models, for a variety of tasks including: masked language modelling, text classification, token classification, zero-shot classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, image segmentation, and object detection.",
55
"main": "./src/transformers.js",
6+
"type": "module",
67
"directories": {
78
"test": "tests"
89
},

src/backends/onnx.js

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
let ONNX;
1+
export let ONNX;
22

33
// TODO support more execution providers (e.g., webgpu)
4-
const executionProviders = ['wasm'];
4+
export const executionProviders = ['wasm'];
55

66
if (typeof process !== 'undefined') {
77
// Running in a node-like environment.
88
// Try to import onnxruntime-node, using onnxruntime-web as a fallback
99
try {
10-
ONNX = require('onnxruntime-node');
10+
ONNX = (await import('onnxruntime-node')).default;
1111
} catch (err) {
1212
console.warn(
1313
"Node.js environment detected, but `onnxruntime-node` was not found. " +
@@ -20,7 +20,7 @@ if (typeof process !== 'undefined') {
2020
// @ts-ignore
2121
global.self = global;
2222

23-
ONNX = require('onnxruntime-web');
23+
ONNX = (await import('onnxruntime-web')).default;
2424

2525
// Disable spawning worker threads for testing.
2626
// This is done by setting numThreads to 1
@@ -33,10 +33,6 @@ if (typeof process !== 'undefined') {
3333

3434
} else {
3535
// Running in a browser-environment, so we just import `onnxruntime-web`
36-
ONNX = require('onnxruntime-web');
36+
ONNX = (await import('onnxruntime-web')).default;
3737
}
3838

39-
module.exports = {
40-
ONNX,
41-
executionProviders,
42-
}

src/env.js

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
const fs = require('fs');
2-
const path = require('path');
1+
import fs from 'fs';
2+
import path from 'path';
3+
import { fileURLToPath } from 'url';
34

4-
const { env: onnx_env } = require('./backends/onnx.js').ONNX;
5+
import { ONNX } from './backends/onnx.js';
6+
const { env: onnx_env } = ONNX;
7+
8+
const __dirname = path.dirname(path.dirname(fileURLToPath(import.meta.url)));
59

610
// check if various APIs are available (depends on environment)
711
const CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self;
@@ -13,20 +17,20 @@ const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE;
1317
// set local model path, based on available APIs
1418
const DEFAULT_LOCAL_PATH = '/models/onnx/quantized/';
1519
const localURL = RUNNING_LOCALLY
16-
? path.join(path.dirname(__dirname), DEFAULT_LOCAL_PATH)
20+
? path.join(__dirname, DEFAULT_LOCAL_PATH)
1721
: DEFAULT_LOCAL_PATH;
1822

1923
// First, set path to wasm files. This is needed when running in a web worker.
2024
// https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
2125
// We use remote wasm files by default to make it easier for newer users.
2226
// In practice, users should probably self-host the necessary .wasm files.
2327
onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
24-
? path.join(path.dirname(__dirname), '/dist/')
28+
? path.join(__dirname, '/dist/')
2529
: 'https://cdn.jsdelivr.net/npm/@xenova/transformers/dist/';
2630

2731

2832
// Global variable used to control exection, with suitable defaults
29-
const env = {
33+
export const env = {
3034
// access onnxruntime-web's environment variables
3135
onnx: onnx_env,
3236

@@ -44,6 +48,9 @@ const env = {
4448

4549
// Whether to use the file system to load files. By default, it is true available.
4650
useFS: FS_AVAILABLE,
51+
52+
// Directory name of module. Useful for resolving local paths.
53+
__dirname,
4754
}
4855

4956

@@ -54,6 +61,3 @@ function isEmpty(obj) {
5461
return Object.keys(obj).length === 0;
5562
}
5663

57-
module.exports = {
58-
env
59-
}

src/fft.js

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* FFT class provides functionality for performing Fast Fourier Transform on arrays
44
* Code adapted from https://www.npmjs.com/package/fft.js
55
*/
6-
class FFT {
6+
export default class FFT {
77
/**
88
* @param {number} size - The size of the input array. Must be a power of two and bigger than 1.
99
* @throws {Error} FFT size must be a power of two and bigger than 1.
@@ -494,5 +494,3 @@ class FFT {
494494
out[outOff + 7] = T3r;
495495
}
496496
}
497-
498-
module.exports = FFT

src/generation.js

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
const { Tensor } = require("./tensor_utils.js");
2-
const {
1+
import { Tensor } from './tensor_utils.js';
2+
import {
33
Callable,
44
exists,
55
log_softmax
6-
} = require("./utils.js");
6+
} from './utils.js';
77

88
/**
99
* A class representing a list of logits processors. A logits processor is a function that modifies the logits
@@ -12,7 +12,7 @@ const {
1212
*
1313
* @extends Callable
1414
*/
15-
class LogitsProcessorList extends Callable {
15+
export class LogitsProcessorList extends Callable {
1616
/**
1717
* Constructs a new instance of `LogitsProcessorList`.
1818
*/
@@ -66,7 +66,7 @@ class LogitsProcessorList extends Callable {
6666
* Base class for processing logits.
6767
* @extends Callable
6868
*/
69-
class LogitsProcessor extends Callable {
69+
export class LogitsProcessor extends Callable {
7070
/**
7171
* Apply the processor to the input logits.
7272
*
@@ -85,7 +85,7 @@ class LogitsProcessor extends Callable {
8585
*
8686
* @extends LogitsProcessor
8787
*/
88-
class ForceTokensLogitsProcessor extends LogitsProcessor {
88+
export class ForceTokensLogitsProcessor extends LogitsProcessor {
8989
/**
9090
* Constructs a new instance of `ForceTokensLogitsProcessor`.
9191
*
@@ -117,7 +117,7 @@ class ForceTokensLogitsProcessor extends LogitsProcessor {
117117
* A LogitsProcessor that forces a BOS token at the beginning of the generated sequence.
118118
* @extends LogitsProcessor
119119
*/
120-
class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
120+
export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
121121
/**
122122
* Create a ForcedBOSTokenLogitsProcessor.
123123
* @param {number} bos_token_id - The ID of the beginning-of-sequence token to be forced.
@@ -146,7 +146,7 @@ class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
146146
*
147147
* @extends LogitsProcessor
148148
*/
149-
class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
149+
export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
150150
/**
151151
* Create a ForcedEOSTokenLogitsProcessor.
152152
* @param {number} max_length - Max length of the sequence.
@@ -174,7 +174,7 @@ class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
174174
* A LogitsProcessor that handles adding timestamps to generated text.
175175
* @extends LogitsProcessor
176176
*/
177-
class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
177+
export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
178178
/**
179179
* Constructs a new WhisperTimeStampLogitsProcessor.
180180
* @param {object} generate_config - The config object passed to the `generate()` method of a transformer model.
@@ -249,7 +249,7 @@ class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
249249
*
250250
* @extends LogitsProcessor
251251
*/
252-
class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
252+
export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
253253
/**
254254
* Create a NoRepeatNGramLogitsProcessor.
255255
* @param {number} no_repeat_ngram_size - The no-repeat-ngram size. All ngrams of this size can only occur once.
@@ -340,7 +340,7 @@ class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
340340
*
341341
* @extends LogitsProcessor
342342
*/
343-
class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
343+
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
344344
/**
345345
* Create a RepetitionPenaltyLogitsProcessor.
346346
* @param {number} penalty - The penalty to apply for repeated tokens.
@@ -372,7 +372,7 @@ class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
372372
}
373373

374374

375-
class GenerationConfig {
375+
export class GenerationConfig {
376376
constructor(kwargs = {}) {
377377
// Parameters that control the length of the output
378378
// TODO: extend the configuration with correct types
@@ -465,15 +465,3 @@ class GenerationConfig {
465465
this.generation_kwargs = kwargs.generation_kwargs ?? {};
466466
}
467467
}
468-
469-
module.exports = {
470-
LogitsProcessor,
471-
LogitsProcessorList,
472-
GenerationConfig,
473-
ForcedBOSTokenLogitsProcessor,
474-
ForcedEOSTokenLogitsProcessor,
475-
WhisperTimeStampLogitsProcessor,
476-
ForceTokensLogitsProcessor,
477-
NoRepeatNGramLogitsProcessor,
478-
RepetitionPenaltyLogitsProcessor
479-
};

src/image_utils.js

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11

2-
const fs = require('fs');
3-
const { getFile, isString } = require('./utils.js');
4-
const { env } = require('./env.js');
2+
import fs from 'fs';
3+
import { getFile, isString } from './utils.js';
4+
import { env } from './env.js';
55

66
let CanvasClass;
7-
let ImageClass = typeof Image !== 'undefined' ? Image : null; // Only used for type-checking
7+
let ImageClass = typeof Image !== 'undefined' ? Image : null;
88

99
let ImageDataClass;
1010
let loadImageFunction;
@@ -14,15 +14,15 @@ if (typeof self !== 'undefined') {
1414
ImageDataClass = ImageData;
1515

1616
} else {
17-
const { Canvas, loadImage, ImageData, Image } = require('canvas');
17+
const { Canvas, loadImage, ImageData, Image } = await import('canvas');
1818
CanvasClass = Canvas;
1919
loadImageFunction = async (/**@type {Blob}*/ b) => await loadImage(Buffer.from(await b.arrayBuffer()));
2020
ImageDataClass = ImageData;
2121
ImageClass = Image;
2222
}
2323

2424

25-
class CustomImage {
25+
export class CustomImage {
2626

2727
/**
2828
* Create a new CustomImage object.
@@ -277,7 +277,3 @@ class CustomImage {
277277
fs.writeFileSync(path, buffer);
278278
}
279279
}
280-
281-
module.exports = {
282-
CustomImage,
283-
};

src/math_utils.js

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
/**
99
* @param {TypedArray} input
1010
*/
11-
function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
11+
export function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
1212
// TODO use mode and align_corners
1313

1414
// Output image dimensions
@@ -86,7 +86,7 @@ function interpolate(input, [in_channels, in_height, in_width], [out_height, out
8686
* @param {number[]} axes
8787
* @returns {[T, number[]]} The transposed array and the new shape.
8888
*/
89-
function transpose_data(array, dims, axes) {
89+
export function transpose_data(array, dims, axes) {
9090
// Calculate the new shape of the transposed array
9191
// and the stride of the original array
9292
const shape = new Array(axes.length);
@@ -117,8 +117,3 @@ function transpose_data(array, dims, axes) {
117117

118118
return [transposedData, shape];
119119
}
120-
121-
module.exports = {
122-
interpolate,
123-
transpose: transpose_data,
124-
}

0 commit comments

Comments
 (0)