diff --git a/src/backends/onnx.js b/src/backends/onnx.ts similarity index 89% rename from src/backends/onnx.js rename to src/backends/onnx.ts index 38cd71337..5900c5380 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.ts @@ -16,21 +16,24 @@ * @module backends/onnx */ -import { env, apis } from '../env.js'; +import { env, apis } from '../env'; // NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`. // In either case, we select the default export if it exists, otherwise we use the named export. import * as ONNX_NODE from 'onnxruntime-node'; import * as ONNX_WEB from 'onnxruntime-web'; +import { DeviceType } from '../utils/devices'; +import { InferenceSession as ONNXInferenceSession } from 'onnxruntime-common'; export { Tensor } from 'onnxruntime-common'; /** * @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders */ +type ONNXExecutionProviders = ONNXInferenceSession.ExecutionProviderConfig; /** @type {Record} */ -const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({ +const DEVICE_TO_EXECUTION_PROVIDER_MAPPING: Record = Object.freeze({ auto: null, // Auto-detect based on device and environment gpu: null, // Auto-detect GPU cpu: 'cpu', // CPU @@ -49,10 +52,10 @@ const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({ * The list of supported devices, sorted by priority/performance. * @type {import("../utils/devices.js").DeviceType[]} */ -const supportedDevices = []; +const supportedDevices: DeviceType[] = []; /** @type {ONNXExecutionProviders[]} */ -let defaultDevices; +let defaultDevices: ONNXExecutionProviders[]; let ONNX; const ORT_SYMBOL = Symbol.for('onnxruntime'); @@ -61,7 +64,7 @@ if (ORT_SYMBOL in globalThis) { ONNX = globalThis[ORT_SYMBOL]; } else if (apis.IS_NODE_ENV) { - ONNX = ONNX_NODE.default ?? ONNX_NODE; + ONNX = ONNX_NODE; // Updated as of ONNX Runtime 1.20.1 // The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries. @@ -109,7 +112,7 @@ const InferenceSession = ONNX.InferenceSession; * @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on. * @returns {ONNXExecutionProviders[]} The execution providers to use for the given device. */ -export function deviceToExecutionProviders(device = null) { +export function deviceToExecutionProviders(device: DeviceType | "auto" | null = null): ONNXExecutionProviders[] { // Use the default execution providers if the user hasn't specified anything if (!device) return defaultDevices; @@ -137,7 +140,7 @@ export function deviceToExecutionProviders(device = null) { * will wait for this Promise to resolve before creating their own InferenceSession. * @type {Promise|null} */ -let wasmInitPromise = null; +let wasmInitPromise: Promise | null = null; /** * Create an ONNX inference session. @@ -146,7 +149,7 @@ let wasmInitPromise = null; * @param {Object} session_config ONNX inference session configuration. * @returns {Promise} The ONNX inference session. */ -export async function createInferenceSession(buffer, session_options, session_config) { +export async function createInferenceSession(buffer: Uint8Array, session_options: ONNXInferenceSession.SessionOptions, session_config: Object): Promise { if (wasmInitPromise) { // A previous session has already initialized the WASM runtime // so we wait for it to resolve before creating this new session. @@ -165,13 +168,13 @@ export async function createInferenceSession(buffer, session_options, session_co * @param {any} x The object to check * @returns {boolean} Whether the object is an ONNX tensor. */ -export function isONNXTensor(x) { +export function isONNXTensor(x: any): boolean { return x instanceof ONNX.Tensor; } /** @type {import('onnxruntime-common').Env} */ // @ts-ignore -const ONNX_ENV = ONNX?.env; +const ONNX_ENV: Env = ONNX?.env; if (ONNX_ENV?.wasm) { // Initialize wasm backend with suitable default settings. @@ -202,7 +205,7 @@ if (ONNX_ENV?.webgpu) { * Check if ONNX's WASM backend is being proxied. * @returns {boolean} Whether ONNX's WASM backend is being proxied. */ -export function isONNXProxy() { +export function isONNXProxy(): boolean { // TODO: Update this when allowing non-WASM backends. return ONNX_ENV?.wasm?.proxy; } diff --git a/src/base/feature_extraction_utils.js b/src/base/feature_extraction_utils.ts similarity index 80% rename from src/base/feature_extraction_utils.js rename to src/base/feature_extraction_utils.ts index 300f5ea1d..85a2d4475 100644 --- a/src/base/feature_extraction_utils.js +++ b/src/base/feature_extraction_utils.ts @@ -1,17 +1,18 @@ import { FEATURE_EXTRACTOR_NAME } from "../utils/constants.js"; import { Callable } from "../utils/generic.js"; -import { getModelJSON } from "../utils/hub.js"; +import { getModelJSON, PretrainedOptions } from "../utils/hub.js"; /** * Base class for feature extractors. */ export class FeatureExtractor extends Callable { + config: Object; /** * Constructs a new FeatureExtractor instance. * * @param {Object} config The configuration for the feature extractor. */ - constructor(config) { + constructor(config: Object) { super(); this.config = config } @@ -27,11 +28,11 @@ export class FeatureExtractor extends Callable { * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a * user or organization name, like `dbmdz/bert-base-german-cased`. * - A path to a *directory* containing feature_extractor files, e.g., `./my_model_directory/`. - * @param {import('../utils/hub.js').PretrainedOptions} options Additional options for loading the feature_extractor. + * @param {import('../utils/hub').PretrainedOptions} options Additional options for loading the feature_extractor. * * @returns {Promise} A new instance of the Feature Extractor class. */ - static async from_pretrained(pretrained_model_name_or_path, options) { + static async from_pretrained(pretrained_model_name_or_path: string, options: PretrainedOptions): Promise { const config = await getModelJSON(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, true, options); return new this(config); } @@ -44,9 +45,10 @@ export class FeatureExtractor extends Callable { * @param {string} feature_extractor The name of the feature extractor. * @private */ -export function validate_audio_inputs(audio, feature_extractor) { +export function validate_audio_inputs(audio: Float32Array | Float64Array, feature_extractor: string) { if (!(audio instanceof Float32Array || audio instanceof Float64Array)) { throw new Error( + // @ts-expect-error TS2339 `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` + `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.` ) diff --git a/src/base/image_processors_utils.js b/src/base/image_processors_utils.ts similarity index 86% rename from src/base/image_processors_utils.js rename to src/base/image_processors_utils.ts index 685de28e5..9cf76c9c4 100644 --- a/src/base/image_processors_utils.js +++ b/src/base/image_processors_utils.ts @@ -3,15 +3,17 @@ import { Tensor, interpolate, stack } from "../utils/tensor.js"; import { bankers_round, max, min, softmax } from "../utils/maths.js"; import { RawImage } from "../utils/image.js"; import { calculateReflectOffset } from "../utils/core.js"; -import { getModelJSON } from "../utils/hub.js"; -import { IMAGE_PROCESSOR_NAME } from '../utils/constants.js'; +import { getModelJSON, PretrainedOptions } from "../utils/hub.js"; +import { IMAGE_PROCESSOR_NAME } from '../utils/constants'; /** * Named tuple to indicate the order we are using is (height x width), * even though the Graphics' industry standard is (width x height). * @typedef {[height: number, width: number]} HeightWidth */ +type HeightWidth = [height: number, width: number]; +type HeightWidthObject = { height: number; width: number; }; /** * @typedef {object} ImageProcessorResult @@ -19,7 +21,24 @@ import { IMAGE_PROCESSOR_NAME } from '../utils/constants.js'; * @property {HeightWidth[]} original_sizes Array of two-dimensional tuples like [[480, 640]]. * @property {HeightWidth[]} reshaped_input_sizes Array of two-dimensional tuples like [[1000, 1330]]. */ +interface ImageProcessorResult { + pixel_values: Tensor; + original_sizes: HeightWidth[]; + reshaped_input_sizes: HeightWidth[]; +} + +/** + * @typedef {object} PreprocessedImage + * @property {HeightWidth} original_size The original size of the image. + * @property {HeightWidth} reshaped_input_size The reshaped input size of the image. + * @property {Tensor} pixel_values The pixel values of the preprocessed image. + */ +interface PreprocessedImage { + original_size: HeightWidth; + reshaped_input_size: HeightWidth; + pixel_values: Tensor; +} /** @@ -31,7 +50,7 @@ import { IMAGE_PROCESSOR_NAME } from '../utils/constants.js'; * @returns {number} The constrained value. * @private */ -function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) { +function constraint_to_multiple_of(val: number, multiple: number, minVal: number = 0, maxVal: number = null): number { const a = val / multiple; let x = bankers_round(a) * multiple; @@ -52,7 +71,7 @@ function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) { * @param {number} divisor The divisor to use. * @returns {[number, number]} The rounded size. */ -function enforce_size_divisibility([width, height], divisor) { +function enforce_size_divisibility([width, height]: [number, number], divisor: number): [number, number] { return [ Math.max(Math.floor(width / divisor), 1) * divisor, Math.max(Math.floor(height / divisor), 1) * divisor @@ -68,7 +87,7 @@ function enforce_size_divisibility([width, height], divisor) { * @param {number[]} arr The coordinate for the center of the box and its width, height dimensions (center_x, center_y, width, height) * @returns {number[]} The coodinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y) */ -export function center_to_corners_format([centerX, centerY, width, height]) { +export function center_to_corners_format([centerX, centerY, width, height]: number[]): number[] { return [ centerX - width / 2, centerY - height / 2, @@ -87,7 +106,7 @@ export function center_to_corners_format([centerX, centerY, width, height]) { * @param {boolean} [is_zero_shot=false] Whether zero-shot object detection was performed. * @return {Object[]} An array of objects containing the post-processed outputs. */ -export function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null, is_zero_shot = false) { +export function post_process_object_detection(outputs: { logits: Tensor; pred_boxes: Tensor; }, threshold: number = 0.5, target_sizes: [number, number][] = null, is_zero_shot: boolean = false): object[] { const out_logits = outputs.logits; const out_bbox = outputs.pred_boxes; const [batch_size, num_boxes, num_classes] = out_logits.dims; @@ -110,7 +129,7 @@ export function post_process_object_detection(outputs, threshold = 0.5, target_s let logit = logits[j]; let indices = []; - let probs; + let probs: string | any[]; if (is_zero_shot) { // Get indices of classes with high enough probability probs = logit.sigmoid().data; @@ -140,8 +159,7 @@ export function post_process_object_detection(outputs, threshold = 0.5, target_s for (const index of indices) { // Some class has a high enough probability - /** @type {number[]} */ - let box = bbox[j].data; + let box: number[] = bbox[j].data; // convert to [x0, y0, x1, y1] format box = center_to_corners_format(box) @@ -167,7 +185,7 @@ export function post_process_object_detection(outputs, threshold = 0.5, target_s * (height, width) of each prediction. If unset, predictions will not be resized. * @returns {{segmentation: Tensor; labels: number[]}[]} The semantic segmentation maps. */ -export function post_process_semantic_segmentation(outputs, target_sizes = null) { +export function post_process_semantic_segmentation(outputs: any, target_sizes: [number, number][] = null): { segmentation: Tensor; labels: number[]; }[] { const logits = outputs.logits; const batch_size = logits.dims[0]; @@ -215,8 +233,8 @@ export function post_process_semantic_segmentation(outputs, target_sizes = null) const index = segmentation_data[j]; hasLabel[index] = index; } - /** @type {number[]} The unique list of labels that were detected */ - const labels = hasLabel.filter(x => x !== undefined); + /** The unique list of labels that were detected */ + const labels: number[] = hasLabel.filter(x => x !== undefined); toReturn.push({ segmentation, labels }); } @@ -233,7 +251,7 @@ export function post_process_semantic_segmentation(outputs, target_sizes = null) * @returns {[Tensor[], number[], number[]]} The binarized masks, the scores, and the labels. * @private */ -function remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels) { +function remove_low_and_no_objects(class_logits: Tensor, mask_logits: Tensor, object_mask_threshold: number, num_labels: number): [Tensor[], number[], number[]] { const mask_probs_item = []; const pred_scores_item = []; @@ -272,12 +290,12 @@ function remove_low_and_no_objects(class_logits, mask_logits, object_mask_thresh * @private */ function check_segment_validity( - mask_labels, - mask_probs, - k, - mask_threshold = 0.5, - overlap_mask_area_threshold = 0.8 -) { + mask_labels: Int32Array, + mask_probs: Tensor[], + k: number, + mask_threshold: number = 0.5, + overlap_mask_area_threshold: number = 0.8 +): [boolean, number[]] { // mask_k is a 1D array of indices, indicating where the mask is equal to k const mask_k = []; let mask_k_area = 0; @@ -316,19 +334,19 @@ function check_segment_validity( * @param {number} mask_threshold The mask threshold. * @param {number} overlap_mask_area_threshold The overlap mask area threshold. * @param {Set} label_ids_to_fuse The label ids to fuse. - * @param {number[]} target_size The target size of the image. + * @param {[number, number]} target_size The target size of the image. * @returns {[Tensor, Array<{id: number, label_id: number, score: number}>]} The computed segments. * @private */ function compute_segments( - mask_probs, - pred_scores, - pred_labels, - mask_threshold, - overlap_mask_area_threshold, - label_ids_to_fuse = null, - target_size = null, -) { + mask_probs: Tensor[], + pred_scores: number[], + pred_labels: number[], + mask_threshold: number, + overlap_mask_area_threshold: number, + label_ids_to_fuse: Set = null, + target_size: [number, number] = null, +): [Tensor, Array<{ id: number; label_id: number; score: number; }>] { const [height, width] = target_size ?? mask_probs[0].dims; const segmentation = new Tensor( @@ -436,7 +454,7 @@ function compute_segments( * @returns {[number, number]} The new height and width of the image. * @throws {Error} If the height or width is smaller than the factor. */ -function smart_resize(height, width, factor = 28, min_pixels = 56 * 56, max_pixels = 14 * 14 * 4 * 1280) { +function smart_resize(height: number, width: number, factor: number = 28, min_pixels: number = 56 * 56, max_pixels: number = 14 * 14 * 4 * 1280): [number, number] { if (height < factor || width < factor) { throw new Error(`height:${height} or width:${width} must be larger than factor:${factor}`); @@ -474,13 +492,13 @@ function smart_resize(height, width, factor = 28, min_pixels = 56 * 56, max_pixe * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>} */ export function post_process_panoptic_segmentation( - outputs, - threshold = 0.5, - mask_threshold = 0.5, - overlap_mask_area_threshold = 0.8, - label_ids_to_fuse = null, - target_sizes = null, -) { + outputs: any, + threshold: number = 0.5, + mask_threshold: number = 0.5, + overlap_mask_area_threshold: number = 0.8, + label_ids_to_fuse: Set = null, + target_sizes: [number, number][] = null, +): Array<{ segmentation: Tensor; segments_info: Array<{ id: number; label_id: number; score: number; }>; }> { if (label_ids_to_fuse === null) { console.warn("`label_ids_to_fuse` unset. No instance will be fused.") label_ids_to_fuse = new Set(); @@ -553,11 +571,10 @@ export function post_process_panoptic_segmentation( * (height, width) of each prediction. If unset, predictions will not be resized. * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>} */ -export function post_process_instance_segmentation(outputs, threshold = 0.5, target_sizes = null) { +export function post_process_instance_segmentation(outputs: any, threshold: number = 0.5, target_sizes: [number, number][] = null): Array<{ segmentation: Tensor; segments_info: Array<{ id: number; label_id: number; score: number; }>; }> { throw new Error('`post_process_instance_segmentation` is not yet implemented.'); } - /** * @typedef {Object} ImageProcessorConfig A configuration object used to create an image processor. * @property {function} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates. @@ -568,8 +585,8 @@ export function post_process_instance_segmentation(outputs, threshold = 0.5, tar * @property {boolean} [do_normalize] Whether to normalize the image pixel values. * @property {boolean} [do_resize] Whether to resize the image. * @property {number} [resample] What method to use for resampling. - * @property {number|Object} [size] The size to resize the image to. - * @property {number|Object} [image_size] The size to resize the image to (same as `size`). + * @property {number|{ height: number; width: number;}} [size] The size to resize the image to. + * @property {number|{ height: number; width: number;}} [image_size] The size to resize the image to (same as `size`). * @property {boolean} [do_flip_channel_order=false] Whether to flip the color channels from RGB to BGR. * Can be overridden by the `do_flip_channel_order` parameter in the `preprocess` method. * @property {boolean} [do_center_crop] Whether to center crop the image to the specified `crop_size`. @@ -583,14 +600,51 @@ export function post_process_instance_segmentation(outputs, threshold = 0.5, tar * @property {number[]} [mean] The mean values for image normalization (same as `image_mean`). * @property {number[]} [std] The standard deviation values for image normalization (same as `image_std`). */ +interface ImageProcessorConfig { + progress_callback?: Function; + image_mean?: number[]; + image_std?: number[]; + do_rescale?: boolean; + rescale_factor?: number; + do_normalize?: boolean; + do_resize?: boolean; + resample?: number; + size?: number | HeightWidthObject; + image_size?: number | HeightWidthObject; + do_flip_channel_order?: boolean; + do_center_crop?: boolean; + do_thumbnail?: boolean; + keep_aspect_ratio?: boolean; + ensure_multiple_of?: number; + mean?: number[]; + std?: number[]; +} export class ImageProcessor extends Callable { + image_mean: number[]; + image_std: number[]; + resample: number; + do_rescale: boolean; + rescale_factor: number; + do_normalize: boolean; + do_thumbnail: boolean; + size: number | HeightWidthObject; + do_resize: boolean; + do_center_crop: boolean; + do_pad: boolean; + pad_size: any; + do_flip_channel_order: boolean; + config: ImageProcessorConfig; + size_divisibility: any; + do_crop_margin: boolean; + do_convert_rgb: boolean; + crop_size: number | HeightWidthObject; /** * Constructs a new `ImageProcessor`. * @param {ImageProcessorConfig} config The configuration object. */ - constructor(config) { + constructor(config: ImageProcessorConfig) { super(); this.image_mean = config.image_mean ?? config.mean; @@ -620,7 +674,7 @@ export class ImageProcessor extends Callable { // @ts-expect-error TS2339 this.do_pad = config.do_pad; - if (this.do_pad && !this.pad_size && this.size && this.size.width !== undefined && this.size.height !== undefined) { + if (this.do_pad && !this.pad_size && this.size && typeof this.size === 'object' && this.size.width !== undefined && this.size.height !== undefined) { // Should pad, but no pad size specified // We infer the pad size from the resize size this.pad_size = this.size @@ -639,7 +693,7 @@ export class ImageProcessor extends Callable { * @param {string | 0 | 1 | 2 | 3 | 4 | 5} [resample=2] The resampling filter to use. * @returns {Promise} The resized image. */ - async thumbnail(image, size, resample = 2) { + async thumbnail(image: RawImage, size: HeightWidthObject, resample: string | 0 | 1 | 2 | 3 | 4 | 5 = 2): Promise { const input_height = image.height; const input_width = image.width; @@ -668,7 +722,7 @@ export class ImageProcessor extends Callable { * @param {number} gray_threshold Value below which pixels are considered to be gray. * @returns {Promise} The cropped image. */ - async crop_margin(image, gray_threshold = 200) { + async crop_margin(image: RawImage, gray_threshold: number = 200): Promise { const gray_image = image.clone().grayscale(); @@ -712,14 +766,14 @@ export class ImageProcessor extends Callable { * @param {number|number[]} [options.constant_values=0] The constant value to use for padding. * @returns {[Float32Array, number[]]} The padded pixel data and image dimensions. */ - pad_image(pixelData, imgDims, padSize, { + pad_image(pixelData: Float32Array, imgDims: number[], padSize: { width: number; height: number; } | number | 'square', { mode = 'constant', center = false, constant_values = 0, - } = {}) { + }: { mode?: 'constant' | 'symmetric'; center?: boolean; constant_values?: number | number[]; } = {}): [Float32Array, number[]] { const [imageHeight, imageWidth, imageChannels] = imgDims; - let paddedImageWidth, paddedImageHeight; + let paddedImageWidth: number, paddedImageHeight: number; if (typeof padSize === 'number') { paddedImageWidth = padSize; paddedImageHeight = padSize; @@ -796,7 +850,7 @@ export class ImageProcessor extends Callable { * @param {Float32Array} pixelData The pixel data to rescale. * @returns {void} */ - rescale(pixelData) { + rescale(pixelData: Float32Array): void { for (let i = 0; i < pixelData.length; ++i) { pixelData[i] = this.rescale_factor * pixelData[i]; } @@ -809,14 +863,14 @@ export class ImageProcessor extends Callable { * @param {any} size The size to use for resizing the image. * @returns {[number, number]} The target (width, height) dimension of the output image after resizing. */ - get_resize_output_image_size(image, size) { + get_resize_output_image_size(image: RawImage, size: any): [number, number] { // `size` comes in many forms, so we need to handle them all here: // 1. `size` is an integer, in which case we resize the image to be a square const [srcWidth, srcHeight] = image.size; - let shortest_edge; - let longest_edge; + let shortest_edge: number; + let longest_edge: number; if (this.do_thumbnail) { // NOTE: custom logic for `Donut` models @@ -908,7 +962,7 @@ export class ImageProcessor extends Callable { * @param {RawImage} image The image to resize. * @returns {Promise} The resized image. */ - async resize(image) { + async resize(image: RawImage): Promise { const [newWidth, newHeight] = this.get_resize_output_image_size(image, this.size); return await image.resize(newWidth, newHeight, { // @ts-expect-error TS2322 @@ -916,12 +970,7 @@ export class ImageProcessor extends Callable { }); } - /** - * @typedef {object} PreprocessedImage - * @property {HeightWidth} original_size The original size of the image. - * @property {HeightWidth} reshaped_input_size The reshaped input size of the image. - * @property {Tensor} pixel_values The pixel values of the preprocessed image. - */ + /** * Preprocesses the given image. @@ -930,13 +979,19 @@ export class ImageProcessor extends Callable { * @param {Object} overrides The overrides for the preprocessing options. * @returns {Promise} The preprocessed image. */ - async preprocess(image, { + async preprocess(image: RawImage, { do_normalize = null, do_pad = null, do_convert_rgb = null, do_convert_grayscale = null, do_flip_channel_order = null, - } = {}) { + }: { + do_normalize?: boolean; + do_pad?: boolean; + do_convert_rgb?: boolean; + do_convert_grayscale?: boolean; + do_flip_channel_order?: boolean; + } = {}): Promise { if (this.do_crop_margin) { // NOTE: Specific to nougat processors. This is done before resizing, // and can be interpreted as a pre-preprocessing step. @@ -968,9 +1023,9 @@ export class ImageProcessor extends Callable { if (this.do_center_crop) { - let crop_width; - let crop_height; - if (Number.isInteger(this.crop_size)) { + let crop_width: number; + let crop_height: number; + if (typeof this.crop_size === 'number') { crop_width = this.crop_size; crop_height = this.crop_size; } else { @@ -981,14 +1036,12 @@ export class ImageProcessor extends Callable { image = await image.center_crop(crop_width, crop_height); } - /** @type {HeightWidth} */ - const reshaped_input_size = [image.height, image.width]; + const reshaped_input_size: HeightWidth = [image.height, image.width]; // NOTE: All pixel-level manipulation (i.e., modifying `pixelData`) // occurs with data in the hwc format (height, width, channels), // to emulate the behavior of the original Python code (w/ numpy). - /** @type {Float32Array} */ - let pixelData = Float32Array.from(image.data); + let pixelData: Float32Array = Float32Array.from(image.data); let imgDims = [image.height, image.width, image.channels]; if (this.do_rescale) { @@ -1058,12 +1111,11 @@ export class ImageProcessor extends Callable { * @param {...any} args Additional arguments. * @returns {Promise} An object containing the concatenated pixel values (and other metadata) of the preprocessed images. */ - async _call(images, ...args) { + async _call(images: RawImage[], ...args: any[]): Promise { if (!Array.isArray(images)) { images = [images]; } - /** @type {PreprocessedImage[]} */ - const imageData = await Promise.all(images.map(x => this.preprocess(x))); + const imageData: PreprocessedImage[] = await Promise.all(images.map(x => this.preprocess(x))); // Stack pixel values const pixel_values = stack(imageData.map(x => x.pixel_values), 0); @@ -1091,11 +1143,11 @@ export class ImageProcessor extends Callable { * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a * user or organization name, like `dbmdz/bert-base-german-cased`. * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`. - * @param {import('../utils/hub.js').PretrainedOptions} options Additional options for loading the processor. + * @param {import('../utils/hub').PretrainedOptions} options Additional options for loading the processor. * * @returns {Promise} A new instance of the Processor class. */ - static async from_pretrained(pretrained_model_name_or_path, options) { + static async from_pretrained(pretrained_model_name_or_path: string, options: PretrainedOptions): Promise { const preprocessorConfig = await getModelJSON(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME, true, options); return new this(preprocessorConfig); } diff --git a/src/base/processing_utils.js b/src/base/processing_utils.ts similarity index 64% rename from src/base/processing_utils.js rename to src/base/processing_utils.ts index adf442f32..69c218718 100644 --- a/src/base/processing_utils.js +++ b/src/base/processing_utils.ts @@ -1,4 +1,3 @@ - /** * @file Processors are used to prepare inputs (e.g., text, image or audio) for a model. * @@ -19,18 +18,31 @@ * * @module processors */ -import { PROCESSOR_NAME } from '../utils/constants.js'; +import { PROCESSOR_NAME } from '../utils/constants'; import { Callable, -} from '../utils/generic.js'; -import { getModelJSON } from '../utils/hub.js'; - -/** - * @typedef {Object} ProcessorProperties Additional processor-specific properties. - * @typedef {import('../utils/hub.js').PretrainedOptions & ProcessorProperties} PretrainedProcessorOptions - * @typedef {import('../tokenizers.js').PreTrainedTokenizer} PreTrainedTokenizer - */ +} from '../utils/generic'; +import { getModelJSON } from '../utils/hub'; +import { PreTrainedTokenizer } from '../tokenizers'; +import { ImageProcessor } from './image_processors_utils'; +import { FeatureExtractor } from './feature_extraction_utils'; + +type ProcessorComponents = { + image_processor?: ImageProcessor; + tokenizer?: PreTrainedTokenizer; + feature_extractor?: FeatureExtractor; +}; + +interface ProcessorProperties { + revision?: string; + cache_dir?: string; + local_files_only?: boolean; + trust_remote_code?: boolean; +} +type PretrainedProcessorOptions = ProcessorProperties & { + [key: string]: any; +}; /** * Represents a Processor that extracts features from an input. @@ -40,38 +52,46 @@ export class Processor extends Callable { 'image_processor_class', 'tokenizer_class', 'feature_extractor_class', - ] + ] as const; static uses_processor_config = false; + // Add static type for component classes + static image_processor_class?: typeof ImageProcessor; + static tokenizer_class?: typeof PreTrainedTokenizer; + static feature_extractor_class?: typeof FeatureExtractor; + + config: object; + components: ProcessorComponents; + /** * Creates a new Processor with the given components * @param {Object} config - * @param {Record} components + * @param {ProcessorComponents} components */ - constructor(config, components) { + constructor(config: object, components: ProcessorComponents) { super(); this.config = config; this.components = components; } /** - * @returns {import('./image_processors_utils.js').ImageProcessor|undefined} The image processor of the processor, if it exists. + * @returns {ImageProcessor|undefined} The image processor of the processor, if it exists. */ - get image_processor() { + get image_processor(): ImageProcessor | undefined { return this.components.image_processor; } /** * @returns {PreTrainedTokenizer|undefined} The tokenizer of the processor, if it exists. */ - get tokenizer() { + get tokenizer(): PreTrainedTokenizer | undefined { return this.components.tokenizer; } /** - * @returns {import('./feature_extraction_utils.js').FeatureExtractor|undefined} The feature extractor of the processor, if it exists. + * @returns {FeatureExtractor|undefined} The feature extractor of the processor, if it exists. */ - get feature_extractor() { + get feature_extractor(): FeatureExtractor | undefined { return this.components.feature_extractor; } @@ -80,7 +100,7 @@ export class Processor extends Callable { * @param {Parameters[1]} options * @returns {ReturnType} */ - apply_chat_template(messages, options = {}) { + apply_chat_template(messages: Parameters[0], options: Parameters[1] = {}): ReturnType { if (!this.tokenizer) { throw new Error('Unable to apply chat template without a tokenizer.'); } @@ -94,7 +114,7 @@ export class Processor extends Callable { * @param {Parameters} args * @returns {ReturnType} */ - batch_decode(...args) { + batch_decode(...args: Parameters): ReturnType { if (!this.tokenizer) { throw new Error('Unable to decode without a tokenizer.'); } @@ -105,7 +125,7 @@ export class Processor extends Callable { * @param {Parameters} args * @returns {ReturnType} */ - decode(...args) { + decode(...args: Parameters): ReturnType { if (!this.tokenizer) { throw new Error('Unable to decode without a tokenizer.'); } @@ -119,7 +139,7 @@ export class Processor extends Callable { * @param {...any} args Additional arguments. * @returns {Promise} A Promise that resolves with the extracted features. */ - async _call(input, ...args) { + async _call(input: any, ...args: any[]): Promise { for (const item of [this.image_processor, this.feature_extractor, this.tokenizer]) { if (item) { return item(input, ...args); @@ -144,18 +164,23 @@ export class Processor extends Callable { * * @returns {Promise} A new instance of the Processor class. */ - static async from_pretrained(pretrained_model_name_or_path, options) { - + static async from_pretrained(pretrained_model_name_or_path: string, options: PretrainedProcessorOptions): Promise { + type ComponentClass = typeof ImageProcessor | typeof PreTrainedTokenizer | typeof FeatureExtractor; + type ComponentClassKey = typeof Processor.classes[number]; + const [config, components] = await Promise.all([ - // TODO: this.uses_processor_config ? getModelJSON(pretrained_model_name_or_path, PROCESSOR_NAME, true, options) : {}, Promise.all( this.classes - .filter((cls) => cls in this) + .filter((cls): cls is ComponentClassKey => + cls in this && + typeof this[cls as keyof typeof this] === 'function' + ) .map(async (cls) => { - const component = await this[cls].from_pretrained(pretrained_model_name_or_path, options); + const ComponentClass = this[cls as keyof typeof this] as ComponentClass; + const component = await ComponentClass.from_pretrained(pretrained_model_name_or_path, options); return [cls.replace(/_class$/, ''), component]; }) ).then(Object.fromEntries) diff --git a/src/configs.js b/src/configs.ts similarity index 77% rename from src/configs.js rename to src/configs.ts index 4303169f3..c533b4412 100644 --- a/src/configs.js +++ b/src/configs.ts @@ -1,4 +1,3 @@ - /** * @file Helper module for using model configs. For more information, see the corresponding * [Python documentation](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoConfig). @@ -27,42 +26,98 @@ * @module configs */ -import { pick } from './utils/core.js'; -import { - getModelJSON, -} from './utils/hub.js'; +import { pick } from './utils/core'; +import { getModelJSON } from './utils/hub'; +import { DEVICE_TYPES } from './utils/devices'; +import { DATA_TYPES } from './utils/dtypes'; + +/** + * @typedef {import('./utils/hub').PretrainedOptions} PretrainedOptions + */ +import type { PretrainedOptions } from './utils/hub'; /** - * @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions + * @typedef {import('./utils/core').ProgressCallback} ProgressCallback */ +export type { ProgressCallback } from './utils/core'; /** - * @typedef {import('./utils/core.js').ProgressCallback} ProgressCallback + * @typedef {import('./utils/core').ProgressInfo} ProgressInfo */ +export type { ProgressInfo } from './utils/core'; /** - * @typedef {import('./utils/core.js').ProgressInfo} ProgressInfo + * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`. + * @typedef {Object} TransformersJSConfig + * @property {import('./utils/tensor').DataType|Record} [kv_cache_dtype] The data type of the key-value cache. + * @property {Record} [free_dimension_overrides] Override the free dimensions of the model. + * See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides + * for more information. + * @property {import('./utils/devices').DeviceType} [device] The default device to use for the model. + * @property {import('./utils/dtypes').DataType|Record} [dtype] The default data type to use for the model. + * @property {boolean|Record} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size). */ +export type TransformersJSConfig = { + kv_cache_dtype?: keyof typeof DATA_TYPES | Record; + free_dimension_overrides?: Record; + device?: keyof typeof DEVICE_TYPES; + dtype?: keyof typeof DATA_TYPES | Record; + use_external_data_format?: boolean | Record; +}; + +// Add interface for config object +export interface ConfigObject { + model_type: string; + is_encoder_decoder: boolean; + text_config?: ConfigObject; + phi_config?: ConfigObject; + decoder?: ConfigObject; + language_config?: ConfigObject; + [key: string]: any; +} + +// Add interface for normalized config +export interface NormalizedConfig { + model_type: string; + is_encoder_decoder: boolean; + num_heads?: number; + num_layers?: number; + hidden_size?: number; + num_attention_heads?: number; + num_key_value_heads?: number; + head_dim?: number; + num_decoder_layers?: number; + num_decoder_heads?: number; + decoder_hidden_size?: number; + num_encoder_layers?: number; + num_encoder_heads?: number; + encoder_hidden_size?: number; + encoder_dim_kv?: number; + decoder_dim_kv?: number; + dim_kv?: number; + multi_query?: boolean; + [key: string]: any; +} /** * Loads a config from the specified path. * @param {string} pretrained_model_name_or_path The path to the config directory. * @param {PretrainedOptions} options Additional options for loading the config. - * @returns {Promise} A promise that resolves with information about the loaded config. + * @returns {Promise} A promise that resolves with information about the loaded config. */ -async function loadConfig(pretrained_model_name_or_path, options) { - return await getModelJSON(pretrained_model_name_or_path, 'config.json', true, options); +async function loadConfig(pretrained_model_name_or_path: string, options: PretrainedOptions): Promise { + return await getModelJSON(pretrained_model_name_or_path, 'config.json', true, options) as ConfigObject; } /** * - * @param {PretrainedConfig} config - * @returns {Object} The normalized configuration. + * @param {ConfigObject} config + * @returns {NormalizedConfig} The normalized configuration. */ -function getNormalizedConfig(config) { - const mapping = {}; +function getNormalizedConfig(config: ConfigObject): NormalizedConfig { + const mapping: Record = {}; - let init_normalized_config = {}; + let init_normalized_config: Partial = {}; switch (config.model_type) { // Sub-configs case 'llava': @@ -70,20 +125,16 @@ function getNormalizedConfig(config) { case 'florence2': case 'llava_onevision': case 'idefics3': - // @ts-expect-error TS2339 - init_normalized_config = getNormalizedConfig(config.text_config); + init_normalized_config = getNormalizedConfig(config.text_config!); break; case 'moondream1': - // @ts-expect-error TS2339 - init_normalized_config = getNormalizedConfig(config.phi_config); + init_normalized_config = getNormalizedConfig(config.phi_config!); break; case 'musicgen': - // @ts-expect-error TS2339 - init_normalized_config = getNormalizedConfig(config.decoder); + init_normalized_config = getNormalizedConfig(config.decoder!); break; case 'multi_modality': - // @ts-expect-error TS2339 - init_normalized_config = getNormalizedConfig(config.language_config); + init_normalized_config = getNormalizedConfig(config.language_config!); break; // Decoder-only models @@ -210,11 +261,10 @@ function getNormalizedConfig(config) { mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'hidden_size'; break; case 'vision-encoder-decoder': - // @ts-expect-error TS2339 - const decoderConfig = getNormalizedConfig(config.decoder); + const decoderConfig = getNormalizedConfig(config.decoder!); const add_encoder_pkv = 'num_decoder_layers' in decoderConfig; - const result = pick(config, ['model_type', 'is_encoder_decoder']); + const result: NormalizedConfig = pick(config, ['model_type', 'is_encoder_decoder']); if (add_encoder_pkv) { // Decoder is part of an encoder-decoder model result.num_decoder_layers = decoderConfig.num_decoder_layers; @@ -231,7 +281,6 @@ function getNormalizedConfig(config) { result.hidden_size = decoderConfig.hidden_size; } return result; - } // NOTE: If `num_attention_heads` is not set, it is assumed to be equal to `num_heads` @@ -250,12 +299,12 @@ function getNormalizedConfig(config) { * @param {PretrainedConfig} config * @returns {Record} */ -export function getKeyValueShapes(config, { +export function getKeyValueShapes(config: PretrainedConfig, { prefix = 'past_key_values', - batch_size=1, -} = {}) { + batch_size = 1, +} = {}): Record { /** @type {Record} */ - const decoderFeeds = {}; + const decoderFeeds: Record = {}; const normalized_config = config.normalized_config; if (normalized_config.is_encoder_decoder && ( @@ -324,32 +373,36 @@ export function getKeyValueShapes(config, { return decoderFeeds; } + /** * Base class for all configuration classes. For more information, see the corresponding * [Python documentation](https://huggingface.co/docs/transformers/main/en/main_classes/configuration#transformers.PretrainedConfig). */ export class PretrainedConfig { - // NOTE: Typo in original + // NOTE: Typo in original /** @type {string|null} */ - model_type = null; + model_type: string | null = null; /** @type {boolean} */ - is_encoder_decoder = false; + is_encoder_decoder: boolean = false; /** @type {number} */ - max_position_embeddings; + max_position_embeddings!: number; /** @type {TransformersJSConfig} */ - 'transformers.js_config'; + 'transformers.js_config'!: TransformersJSConfig; + + /** @type {NormalizedConfig} */ + normalized_config!: NormalizedConfig; /** * Create a new PreTrainedTokenizer instance. - * @param {Object} configJSON The JSON of the config. + * @param {ConfigObject} configJSON The JSON of the config. */ - constructor(configJSON) { + constructor(configJSON: ConfigObject) { Object.assign(this, configJSON); - this.normalized_config = getNormalizedConfig(this); + this.normalized_config = getNormalizedConfig(this as ConfigObject); } /** @@ -361,15 +414,15 @@ export class PretrainedConfig { * * @returns {Promise} A new instance of the `PretrainedConfig` class. */ - static async from_pretrained(pretrained_model_name_or_path, { + static async from_pretrained(pretrained_model_name_or_path: string, { progress_callback = null, config = null, cache_dir = null, local_files_only = false, revision = 'main', - } = {}) { + }: PretrainedOptions = {}): Promise { if (config && !(config instanceof PretrainedConfig)) { - config = new PretrainedConfig(config); + config = new PretrainedConfig(config as ConfigObject); } const data = config ?? await loadConfig(pretrained_model_name_or_path, { @@ -378,7 +431,7 @@ export class PretrainedConfig { cache_dir, local_files_only, revision, - }) + }); return new this(data); } } @@ -391,19 +444,7 @@ export class PretrainedConfig { */ export class AutoConfig { /** @type {typeof PretrainedConfig.from_pretrained} */ - static async from_pretrained(...args) { - return PretrainedConfig.from_pretrained(...args); + static async from_pretrained(pretrained_model_name_or_path: string, options: PretrainedOptions = {}): Promise { + return PretrainedConfig.from_pretrained(pretrained_model_name_or_path, options); } } - -/** - * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`. - * @typedef {Object} TransformersJSConfig - * @property {import('./utils/tensor.js').DataType|Record} [kv_cache_dtype] The data type of the key-value cache. - * @property {Record} [free_dimension_overrides] Override the free dimensions of the model. - * See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides - * for more information. - * @property {import('./utils/devices.js').DeviceType} [device] The default device to use for the model. - * @property {import('./utils/dtypes.js').DataType|Record} [dtype] The default data type to use for the model. - * @property {boolean|Record} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size). - */ diff --git a/src/env.js b/src/env.ts similarity index 84% rename from src/env.js rename to src/env.ts index f351c47f8..b2692a882 100644 --- a/src/env.js +++ b/src/env.ts @@ -25,12 +25,14 @@ import fs from 'fs'; import path from 'path'; import url from 'url'; +import { Env } from 'onnxruntime-common'; +import { ICache } from './utils/hub'; const VERSION = '3.3.3'; // Check if various APIs are available (depends on environment) const IS_BROWSER_ENV = typeof window !== "undefined" && typeof window.document !== "undefined"; -const IS_WEBWORKER_ENV = typeof self !== "undefined" && self.constructor?.name === 'DedicatedWorkerGlobalScope'; +const IS_WEBWORKER_ENV = typeof self !== "undefined" && self.constructor?.name === 'DedicatedWorkerGlobalScope'; const IS_WEB_CACHE_AVAILABLE = typeof self !== "undefined" && 'caches' in self; const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator; @@ -43,38 +45,42 @@ const IS_PATH_AVAILABLE = !isEmpty(path); /** * A read-only object containing information about the APIs available in the current environment. */ -export const apis = Object.freeze({ +type APIs = { /** Whether we are running in a browser environment (and not a web worker) */ - IS_BROWSER_ENV, - + IS_BROWSER_ENV: boolean; /** Whether we are running in a web worker environment */ - IS_WEBWORKER_ENV, - + IS_WEBWORKER_ENV: boolean; /** Whether the Cache API is available */ - IS_WEB_CACHE_AVAILABLE, - + IS_WEB_CACHE_AVAILABLE: boolean; /** Whether the WebGPU API is available */ - IS_WEBGPU_AVAILABLE, - + IS_WEBGPU_AVAILABLE: boolean; /** Whether the WebNN API is available */ - IS_WEBNN_AVAILABLE, - + IS_WEBNN_AVAILABLE: boolean; /** Whether the Node.js process API is available */ - IS_PROCESS_AVAILABLE, - + IS_PROCESS_AVAILABLE: boolean; /** Whether we are running in a Node.js environment */ - IS_NODE_ENV, - + IS_NODE_ENV: boolean; /** Whether the filesystem API is available */ - IS_FS_AVAILABLE, - + IS_FS_AVAILABLE: boolean; /** Whether the path API is available */ + IS_PATH_AVAILABLE: boolean; +}; + +export const apis: Readonly = Object.freeze({ + IS_BROWSER_ENV, + IS_WEBWORKER_ENV, + IS_WEB_CACHE_AVAILABLE, + IS_WEBGPU_AVAILABLE, + IS_WEBNN_AVAILABLE, + IS_PROCESS_AVAILABLE, + IS_NODE_ENV, + IS_FS_AVAILABLE, IS_PATH_AVAILABLE, }); const RUNNING_LOCALLY = IS_FS_AVAILABLE && IS_PATH_AVAILABLE; -let dirname__ = './'; +let dirname__: string = './'; if (RUNNING_LOCALLY) { // NOTE: We wrap `import.meta` in a call to `Object` to prevent Webpack from trying to bundle it in CommonJS. // Although we get the warning: "Accessing import.meta directly is unsupported (only property access or destructuring is supported)", @@ -121,8 +127,25 @@ const localModelPath = RUNNING_LOCALLY * implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache */ -/** @type {TransformersEnvironment} */ -export const env = { +export interface TransformersEnvironment { + version: string; + backends: { + onnx: Partial; + }; + allowRemoteModels: boolean; + remoteHost: string; + remotePathTemplate: string; + allowLocalModels: boolean; + localModelPath: string; + useFS: boolean; + useBrowserCache: boolean; + useFSCache: boolean; + cacheDir: string | null; + useCustomCache: boolean; + customCache: ICache | null; +} + +export const env: TransformersEnvironment = { version: VERSION, /////////////////// Backends settings /////////////////// @@ -152,11 +175,10 @@ export const env = { ////////////////////////////////////////////////////// } - /** * @param {Object} obj * @private */ -function isEmpty(obj) { +function isEmpty(obj: Record): boolean { return Object.keys(obj).length === 0; } diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.ts similarity index 85% rename from src/generation/configuration_utils.js rename to src/generation/configuration_utils.ts index 8474057da..58629c984 100644 --- a/src/generation/configuration_utils.js +++ b/src/generation/configuration_utils.ts @@ -1,4 +1,3 @@ - /** * @module generation/configuration_utils */ @@ -17,14 +16,14 @@ export class GenerationConfig { * @type {number} * @default 20 */ - max_length = 20; + max_length: number = 20; /** * The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. * @type {number} * @default null */ - max_new_tokens = null; + max_new_tokens: number = null; /** * The minimum length of the sequence to be generated. @@ -33,14 +32,14 @@ export class GenerationConfig { * @type {number} * @default 0 */ - min_length = 0; + min_length: number = 0; /** * The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt. * @type {number} * @default null */ - min_new_tokens = null; + min_new_tokens: number = null; /** * Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: @@ -50,7 +49,7 @@ export class GenerationConfig { * @type {boolean|"never"} * @default false */ - early_stopping = false; + early_stopping: boolean | "never" = false; /** * The maximum amount of time you allow the computation to run for in seconds. @@ -58,7 +57,7 @@ export class GenerationConfig { * @type {number} * @default null */ - max_time = null; + max_time: number = null; // Parameters that control the generation strategy used /** @@ -66,14 +65,14 @@ export class GenerationConfig { * @type {boolean} * @default false */ - do_sample = false; + do_sample: boolean = false; /** * Number of beams for beam search. 1 means no beam search. * @type {number} * @default 1 */ - num_beams = 1; + num_beams: number = 1; /** * Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. @@ -81,21 +80,21 @@ export class GenerationConfig { * @type {number} * @default 1 */ - num_beam_groups = 1; + num_beam_groups: number = 1; /** * The values balance the model confidence and the degeneration penalty in contrastive search decoding. * @type {number} * @default null */ - penalty_alpha = null; + penalty_alpha: number = null; /** * Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. * @type {boolean} * @default true */ - use_cache = true; + use_cache: boolean = true; // Parameters for manipulation of the model output logits /** @@ -103,21 +102,21 @@ export class GenerationConfig { * @type {number} * @default 1.0 */ - temperature = 1.0; + temperature: number = 1.0; /** * The number of highest probability vocabulary tokens to keep for top-k-filtering. * @type {number} * @default 50 */ - top_k = 50; + top_k: number = 50; /** * If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. * @type {number} * @default 1.0 */ - top_p = 1.0; + top_p: number = 1.0; /** * Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. @@ -126,7 +125,7 @@ export class GenerationConfig { * @type {number} * @default 1.0 */ - typical_p = 1.0; + typical_p: number = 1.0; /** * If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled. @@ -135,7 +134,7 @@ export class GenerationConfig { * @type {number} * @default 0.0 */ - epsilon_cutoff = 0.0; + epsilon_cutoff: number = 0.0; /** * Eta sampling is a hybrid of locally typical sampling and epsilon sampling. @@ -145,7 +144,7 @@ export class GenerationConfig { * @type {number} * @default 0.0 */ - eta_cutoff = 0.0; + eta_cutoff: number = 0.0; /** * This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. @@ -153,7 +152,7 @@ export class GenerationConfig { * @type {number} * @default 0.0 */ - diversity_penalty = 0.0; + diversity_penalty: number = 0.0; /** * The parameter for repetition penalty. 1.0 means no penalty. @@ -161,7 +160,7 @@ export class GenerationConfig { * @type {number} * @default 1.0 */ - repetition_penalty = 1.0; + repetition_penalty: number = 1.0; /** * The paramater for encoder_repetition_penalty. @@ -170,7 +169,7 @@ export class GenerationConfig { * @type {number} * @default 1.0 */ - encoder_repetition_penalty = 1.0; + encoder_repetition_penalty: number = 1.0; /** * Exponential penalty to the length that is used with beam-based generation. @@ -179,14 +178,14 @@ export class GenerationConfig { * @type {number} * @default 1.0 */ - length_penalty = 1.0; + length_penalty: number = 1.0; /** * If set to int > 0, all ngrams of that size can only occur once. * @type {number} * @default 0 */ - no_repeat_ngram_size = 0; + no_repeat_ngram_size: number = 0; /** * List of token ids that are not allowed to be generated. @@ -195,7 +194,7 @@ export class GenerationConfig { * @type {number[][]} * @default null */ - bad_words_ids = null; + bad_words_ids: number[][] = null; /** * List of token ids that must be generated. @@ -204,7 +203,7 @@ export class GenerationConfig { * @type {number[][]|number[][][]} * @default null */ - force_words_ids = null; + force_words_ids: number[][] | number[][][] = null; /** * Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). @@ -212,14 +211,14 @@ export class GenerationConfig { * @type {boolean} * @default false */ - renormalize_logits = false; + renormalize_logits: boolean = false; /** * Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible. * @type {Object[]} * @default null */ - constraints = null; + constraints: object[] = null; /** * The id of the token to force as the first generated token after the `decoder_start_token_id`. @@ -227,7 +226,7 @@ export class GenerationConfig { * @type {number} * @default null */ - forced_bos_token_id = null; + forced_bos_token_id: number = null; /** * The id of the token to force as the last generated token when `max_length` is reached. @@ -235,13 +234,13 @@ export class GenerationConfig { * @type {number|number[]} * @default null */ - forced_eos_token_id = null; + forced_eos_token_id: number | number[] = null; /** * Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation. * @type {boolean} */ - remove_invalid_values = false; + remove_invalid_values: boolean = false; /** * This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. @@ -249,7 +248,7 @@ export class GenerationConfig { * @type {[number, number]} * @default null */ - exponential_decay_length_penalty = null; + exponential_decay_length_penalty: [number, number] = null; /** * A list of tokens that will be suppressed at generation. @@ -257,14 +256,14 @@ export class GenerationConfig { * @type {number[]} * @default null */ - suppress_tokens = null; + suppress_tokens: number[] = null; /** * A streamer that will be used to stream the generation. - * @type {import('./streamers.js').TextStreamer} + * @type {import('./streamers').TextStreamer} * @default null */ - streamer = null; + streamer: import('./streamers').TextStreamer = null; /** * A list of tokens that will be suppressed at the beginning of the generation. @@ -272,7 +271,7 @@ export class GenerationConfig { * @type {number[]} * @default null */ - begin_suppress_tokens = null; + begin_suppress_tokens: number[] = null; /** * A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. @@ -280,7 +279,7 @@ export class GenerationConfig { * @type {[number, number][]} * @default null */ - forced_decoder_ids = null; + forced_decoder_ids: [number, number][] = null; /** * The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. @@ -289,7 +288,7 @@ export class GenerationConfig { * @type {number} * @default null */ - guidance_scale = null; + guidance_scale: number = null; // Parameters that define the output variables of `generate` /** @@ -297,7 +296,7 @@ export class GenerationConfig { * @type {number} * @default 1 */ - num_return_sequences = 1; + num_return_sequences: number = 1; /** * Whether or not to return the attentions tensors of all attention layers. @@ -305,7 +304,7 @@ export class GenerationConfig { * @type {boolean} * @default false */ - output_attentions = false; + output_attentions: boolean = false; /** * Whether or not to return the hidden states of all layers. @@ -313,7 +312,7 @@ export class GenerationConfig { * @type {boolean} * @default false */ - output_hidden_states = false; + output_hidden_states: boolean = false; /** * Whether or not to return the prediction scores. @@ -321,14 +320,14 @@ export class GenerationConfig { * @type {boolean} * @default false */ - output_scores = false; + output_scores: boolean = false; /** * Whether or not to return a `ModelOutput` instead of a plain tuple. * @type {boolean} * @default false */ - return_dict_in_generate = false; + return_dict_in_generate: boolean = false; // Special tokens that can be used at generation time /** @@ -336,14 +335,14 @@ export class GenerationConfig { * @type {number} * @default null */ - pad_token_id = null; + pad_token_id: number = null; /** * The id of the *beginning-of-sequence* token. * @type {number} * @default null */ - bos_token_id = null; + bos_token_id: number = null; /** * The id of the *end-of-sequence* token. @@ -351,7 +350,7 @@ export class GenerationConfig { * @type {number|number[]} * @default null */ - eos_token_id = null; + eos_token_id: number | number[] = null; // Generation parameters exclusive to encoder-decoder models /** @@ -359,14 +358,14 @@ export class GenerationConfig { * @type {number} * @default 0 */ - encoder_no_repeat_ngram_size = 0; + encoder_no_repeat_ngram_size: number = 0; /** * If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. * @type {number} * @default null */ - decoder_start_token_id = null; + decoder_start_token_id: number = null; // Wild card /** @@ -375,14 +374,14 @@ export class GenerationConfig { * @type {Object} * @default {} */ - generation_kwargs = {}; + generation_kwargs: object = {}; /** * - * @param {GenerationConfig|import('../configs.js').PretrainedConfig} config + * @param {GenerationConfig|import('../configs').PretrainedConfig} config */ - constructor(config) { - Object.assign(this, pick(config, Object.getOwnPropertyNames(this))); + constructor(config: GenerationConfig | import('../configs').PretrainedConfig) { + Object.assign(this, pick(config as GenerationConfig, Object.getOwnPropertyNames(this) as Array)); } } diff --git a/src/generation/logits_process.js b/src/generation/logits_process.ts similarity index 82% rename from src/generation/logits_process.js rename to src/generation/logits_process.ts index f82634f75..81fa9af3c 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.ts @@ -20,7 +20,7 @@ export class LogitsProcessor extends Callable { * @param {Tensor} logits The logits to process. * @throws {Error} Throws an error if `_call` is not implemented in the subclass. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor) { throw Error("`_call` should be implemented in a subclass") } } @@ -38,7 +38,7 @@ export class LogitsWarper extends Callable { * @param {Tensor} logits The logits to process. * @throws {Error} Throws an error if `_call` is not implemented in the subclass. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor) { throw Error("`_call` should be implemented in a subclass") } } @@ -50,6 +50,12 @@ export class LogitsWarper extends Callable { * batch of logits. */ export class LogitsProcessorList extends Callable { + /** + * The list of logits processors. + * @type {LogitsProcessor[]} + */ + processors: LogitsProcessor[]; + /** * Constructs a new instance of `LogitsProcessorList`. */ @@ -63,7 +69,7 @@ export class LogitsProcessorList extends Callable { * * @param {LogitsProcessor} item The logits processor function to add. */ - push(item) { + push(item: LogitsProcessor) { this.processors.push(item); } @@ -72,7 +78,7 @@ export class LogitsProcessorList extends Callable { * * @param {LogitsProcessor[]} items The logits processor functions to add. */ - extend(items) { + extend(items: LogitsProcessor[]) { this.processors.push(...items); } @@ -82,7 +88,7 @@ export class LogitsProcessorList extends Callable { * @param {bigint[][]} input_ids The input IDs for the language model. * @param {Tensor} logits */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor) { let toReturn = logits; // NOTE: Most processors modify logits inplace for (const processor of this.processors) { @@ -138,11 +144,17 @@ export class LogitsProcessorList extends Callable { * A LogitsProcessor that forces a BOS token at the beginning of the generated sequence. */ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor { + /** + * The ID of the beginning-of-sequence token to be forced. + * @type {number} + */ + bos_token_id: number; + /** * Create a ForcedBOSTokenLogitsProcessor. * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced. */ - constructor(bos_token_id) { + constructor(bos_token_id: number) { super(); this.bos_token_id = bos_token_id; } @@ -153,7 +165,7 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits. * @returns {Tensor} The logits with BOS token forcing. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { for (let i = 0; i < input_ids.length; ++i) { if (input_ids[i].length === 1) { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); @@ -169,12 +181,24 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor { * A logits processor that enforces the specified token as the last generated token when `max_length` is reached. */ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor { + /** + * The maximum length of the sequence to be generated. + * @type {number} + */ + max_length: number; + + /** + * The id(s) of the *end-of-sequence* token. + * @type {number[]} + */ + eos_token_id: number[]; + /** * Create a ForcedEOSTokenLogitsProcessor. * @param {number} max_length The maximum length of the sequence to be generated. * @param {number|number[]} eos_token_id The id(s) of the *end-of-sequence* token. */ - constructor(max_length, eos_token_id) { + constructor(max_length: number, eos_token_id: number | number[]) { super(); this.max_length = max_length; this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; @@ -186,10 +210,10 @@ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor { * @param {bigint[][]} input_ids The input ids. * @param {Tensor} logits The logits tensor. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor) { for (let i = 0; i < input_ids.length; ++i) { if (input_ids[i].length === this.max_length - 1) { - const batch_logits_data = /** @type {Float32Array} */(logits[i].data); + const batch_logits_data: Float32Array = logits[i].data; batch_logits_data.fill(-Infinity); for (const eos_token of this.eos_token_id) { batch_logits_data[eos_token] = 0; @@ -206,12 +230,24 @@ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor { * `begin_suppress_tokens` at not sampled at the begining of the generation. */ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor { + /** + * The IDs of the tokens to suppress. + * @type {number[]} + */ + begin_suppress_tokens: number[]; + + /** + * The number of tokens to generate before suppressing tokens. + * @type {number} + */ + begin_index: number; + /** * Create a SuppressTokensAtBeginLogitsProcessor. * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress. * @param {number} begin_index The number of tokens to generate before suppressing tokens. */ - constructor(begin_suppress_tokens, begin_index) { + constructor(begin_suppress_tokens: number[], begin_index: number) { super(); this.begin_suppress_tokens = begin_suppress_tokens; this.begin_index = begin_index; @@ -223,7 +259,7 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits. * @returns {Tensor} The logits with BOS token forcing. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { for (let i = 0; i < input_ids.length; ++i) { if (input_ids[i].length === this.begin_index) { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); @@ -241,11 +277,42 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor { */ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor { /** + * The ID of the *end-of-sequence* token. + * @type {number} + */ + eos_token_id: number; + + /** + * The ID of the token that indicates no timestamps are present. + * @type {number} + */ + no_timestamps_token_id: number; + + /** + * The ID of the first timestamp token. + * @type {number} + */ + timestamp_begin: number; + + /** + * The index of the first timestamp token. + * @type {number} + */ + begin_index: number; + + /** + * The maximum index of the first timestamp token. + * @type {number} + */ + max_initial_timestamp_index: number; + + /** + * The maximum index of the first timestamp token. * Constructs a new WhisperTimeStampLogitsProcessor. - * @param {import('../models/whisper/generation_whisper.js').WhisperGenerationConfig} generate_config The config object passed to the `generate()` method of a transformer model. + * @param {import('../models/whisper/generation_whisper').WhisperGenerationConfig} generate_config The config object passed to the `generate()` method of a transformer model. * @param {number[]} init_tokens The initial tokens of the input sequence. */ - constructor(generate_config, init_tokens) { + constructor(generate_config: import('../models/whisper/generation_whisper').WhisperGenerationConfig, init_tokens: number[]) { super(); this.eos_token_id = Array.isArray(generate_config.eos_token_id) @@ -268,7 +335,7 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits output by the model. * @returns {Tensor} The modified logits. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { for (let i = 0; i < input_ids.length; ++i) { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); @@ -318,11 +385,17 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor { * A logits processor that disallows ngrams of a certain size to be repeated. */ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { + /** + * The no-repeat-ngram size. All ngrams of this size can only occur once. + * @type {number} + */ + no_repeat_ngram_size: number; + /** * Create a NoRepeatNGramLogitsProcessor. * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once. */ - constructor(no_repeat_ngram_size) { + constructor(no_repeat_ngram_size: number) { super(); this.no_repeat_ngram_size = no_repeat_ngram_size; } @@ -332,11 +405,11 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { * @param {bigint[]} prevInputIds List of previous input ids * @returns {Map} Map of generated n-grams */ - getNgrams(prevInputIds) { + getNgrams(prevInputIds: bigint[]): Map { const curLen = prevInputIds.length; /**@type {number[][]} */ - const ngrams = []; + const ngrams: number[][] = []; for (let j = 0; j < curLen + 1 - this.no_repeat_ngram_size; ++j) { const ngram = []; for (let k = 0; k < this.no_repeat_ngram_size; ++k) { @@ -346,7 +419,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { } /** @type {Map} */ - const generatedNgram = new Map(); + const generatedNgram: Map = new Map(); for (const ngram of ngrams) { const prevNgram = ngram.slice(0, ngram.length - 1); const prevNgramKey = JSON.stringify(prevNgram); @@ -363,7 +436,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { * @param {bigint[]} prevInputIds List of previous input ids * @returns {number[]} Map of generated n-grams */ - getGeneratedNgrams(bannedNgrams, prevInputIds) { + getGeneratedNgrams(bannedNgrams: Map, prevInputIds: bigint[]): number[] { const ngramIdx = prevInputIds.slice(prevInputIds.length + 1 - this.no_repeat_ngram_size, prevInputIds.length); const banned = bannedNgrams.get(JSON.stringify(ngramIdx.map(Number))) ?? []; return banned; @@ -374,7 +447,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { * @param {bigint[]} prevInputIds List of previous input ids * @returns {number[]} Map of generated n-grams */ - calcBannedNgramTokens(prevInputIds) { + calcBannedNgramTokens(prevInputIds: bigint[]): number[] { const bannedTokens = []; if (prevInputIds.length + 1 < this.no_repeat_ngram_size) { // return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet @@ -393,7 +466,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits. * @returns {Tensor} The logits with no-repeat-ngram processing. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { for (let i = 0; i < input_ids.length; ++i) { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); const bannedTokens = this.calcBannedNgramTokens(input_ids[i]); @@ -418,12 +491,18 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { */ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { /** - * Create a RepetitionPenaltyLogitsProcessor. - * @param {number} penalty The parameter for repetition penalty. + * The parameter for repetition penalty. * - 1.0 means no penalty. Above 1.0 penalizes previously generated tokens. * - Between 0.0 and 1.0 rewards previously generated tokens. + * @type {number} */ - constructor(penalty) { + penalty: number; + + /** + * Create a RepetitionPenaltyLogitsProcessor. + * @param {number} penalty The parameter for repetition penalty. + */ + constructor(penalty: number) { super(); this.penalty = penalty; } @@ -434,7 +513,7 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits. * @returns {Tensor} The logits with repetition penalty processing. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { for (let i = 0; i < input_ids.length; ++i) { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); for (const input_id of new Set(input_ids[i])) { @@ -455,12 +534,24 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { * A logits processor that enforces a minimum number of tokens. */ export class MinLengthLogitsProcessor extends LogitsProcessor { + /** + * The minimum length below which the score of `eos_token_id` is set to negative infinity. + * @type {number} + */ + min_length: number; + + /** + * The ID/IDs of the end-of-sequence token. + * @type {number[]} + */ + eos_token_id: number[]; + /** * Create a MinLengthLogitsProcessor. * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity. * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. */ - constructor(min_length, eos_token_id) { + constructor(min_length: number, eos_token_id: number | number[]) { super(); this.min_length = min_length; this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; @@ -472,7 +563,7 @@ export class MinLengthLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { for (let i = 0; i < input_ids.length; ++i) { if (input_ids[i].length < this.min_length) { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); @@ -491,13 +582,31 @@ export class MinLengthLogitsProcessor extends LogitsProcessor { * A logits processor that enforces a minimum number of new tokens. */ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { + /** + * The input tokens length. + * @type {number} + */ + prompt_length_to_skip: number; + + /** + * The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity. + * @type {number} + */ + min_new_tokens: number; + + /** + * The ID/IDs of the end-of-sequence token. + * @type {number[]} + */ + eos_token_id: number[]; + /** * Create a MinNewTokensLengthLogitsProcessor. * @param {number} prompt_length_to_skip The input tokens length. * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity. * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. */ - constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) { + constructor(prompt_length_to_skip: number, min_new_tokens: number, eos_token_id: number | number[]) { super(); this.prompt_length_to_skip = prompt_length_to_skip; this.min_new_tokens = min_new_tokens; @@ -510,7 +619,7 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { for (let i = 0; i < input_ids.length; ++i) { const new_tokens_length = input_ids[i].length - this.prompt_length_to_skip; if (new_tokens_length < this.min_new_tokens) { @@ -526,12 +635,24 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { } export class NoBadWordsLogitsProcessor extends LogitsProcessor { + /** + * The list of bad words. + * @type {number[][]} + */ + bad_words_ids: number[][]; + + /** + * The ID/IDs of the end-of-sequence token. + * @type {number[]} + */ + eos_token_id: number[]; + /** * Create a `NoBadWordsLogitsProcessor`. * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated. * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. */ - constructor(bad_words_ids, eos_token_id) { + constructor(bad_words_ids: number[][], eos_token_id: number | number[]) { super(); this.bad_words_ids = bad_words_ids; this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; @@ -543,7 +664,7 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { for (let i = 0; i < input_ids.length; ++i) { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); const ids = input_ids[i]; @@ -581,6 +702,11 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor { * See [the paper](https://arxiv.org/abs/2306.05284) for more information. */ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor { + /** + * The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. + * @type {number} + */ + guidance_scale: number; /** * Create a `ClassifierFreeGuidanceLogitsProcessor`. @@ -588,7 +714,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor { * Higher guidance scale encourages the model to generate samples that are more closely linked to the input * prompt, usually at the expense of poorer quality. */ - constructor(guidance_scale) { + constructor(guidance_scale: number) { super(); if (guidance_scale <= 1) { throw new Error( @@ -604,7 +730,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor { * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { if (logits.dims[0] !== 2 * input_ids.length) { throw new Error( `Logits should have twice the batch size of the input ids, the first half of batches corresponding to ` + @@ -632,13 +758,19 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor { * that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`]. */ export class TemperatureLogitsWarper extends LogitsWarper { + /** + * The temperature for temperature (exponential scaling output probability distribution). + * @type {number} + */ + temperature: number; + /** * Create a `TemperatureLogitsWarper`. * @param {number} temperature Strictly positive float value used to modulate the logits distribution. * A value smaller than `1` decreases randomness (and vice versa), with `0` being equivalent to shifting * all probability mass to the most likely token. */ - constructor(temperature) { + constructor(temperature: number) { super(); if (typeof temperature !== 'number' || temperature <= 0) { @@ -658,7 +790,7 @@ export class TemperatureLogitsWarper extends LogitsWarper { * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ - _call(input_ids, logits) { + _call(input_ids: bigint[][], logits: Tensor): Tensor { const batch_logits_data = /** @type {Float32Array} */(logits.data); for (let i = 0; i < batch_logits_data.length; ++i) { batch_logits_data[i] /= this.temperature; @@ -672,6 +804,24 @@ export class TemperatureLogitsWarper extends LogitsWarper { * Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`]. */ export class TopPLogitsWarper extends LogitsWarper { + /** + * The probability cutoff for top-p sampling. + * @type {number} + */ + top_p: number; + + /** + * The filter value for top-p sampling. + * @type {number} + */ + filter_value: number; + + /** + * The minimum number of tokens that cannot be filtered. + * @type {number} + */ + min_tokens_to_keep: number; + /** * Create a `TopPLogitsWarper`. * @param {number} top_p If set to < 1, only the smallest set of most probable tokens with @@ -680,10 +830,10 @@ export class TopPLogitsWarper extends LogitsWarper { * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value. * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered. */ - constructor(top_p, { + constructor(top_p: number, { filter_value = -Infinity, min_tokens_to_keep = 1, - } = {}) { + }: { filter_value?: number; min_tokens_to_keep?: number; } = {}) { super(); if (top_p < 0 || top_p > 1.0) { throw new Error(`\`top_p\` must be a float > 0 and < 1, but is ${top_p}`) @@ -703,6 +853,18 @@ export class TopPLogitsWarper extends LogitsWarper { * Often used together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. */ export class TopKLogitsWarper extends LogitsWarper { + /** + * The number of top tokens to keep. + * @type {number} + */ + top_k: number; + + /** + * The filter value for top-k sampling. + * @type {number} + */ + filter_value: number; + /** * Create a `TopKLogitsWarper`. * @param {number} top_k If set to > 0, only the top `top_k` tokens are kept for generation. @@ -710,10 +872,10 @@ export class TopKLogitsWarper extends LogitsWarper { * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value. * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered. */ - constructor(top_k, { + constructor(top_k: number, { filter_value = -Infinity, min_tokens_to_keep = 1, - } = {}) { + }: { filter_value?: number; min_tokens_to_keep?: number; } = {}) { super(); if (!Number.isInteger(top_k) || top_k < 0) { throw new Error(`\`top_k\` must be a positive integer, but is ${top_k}`) diff --git a/src/generation/logits_sampler.js b/src/generation/logits_sampler.ts similarity index 86% rename from src/generation/logits_sampler.js rename to src/generation/logits_sampler.ts index 46b74e081..7adb47a0e 100644 --- a/src/generation/logits_sampler.js +++ b/src/generation/logits_sampler.ts @@ -9,18 +9,21 @@ import { Tensor, topk } from "../utils/tensor.js"; import { max, softmax, -} from '../utils/maths.js'; -import { GenerationConfig } from '../generation/configuration_utils.js'; +} from '../utils/maths'; +import type { GenerationConfig } from './configuration_utils'; +import type { DataArray } from '../transformers'; /** * Sampler is a base class for all sampling methods used for text generation. */ export class LogitsSampler extends Callable { + generation_config: GenerationConfig; + /** * Creates a new Sampler object with the specified generation config. * @param {GenerationConfig} generation_config The generation config. */ - constructor(generation_config) { + constructor(generation_config: GenerationConfig) { super(); this.generation_config = generation_config; } @@ -30,7 +33,7 @@ export class LogitsSampler extends Callable { * @param {Tensor} logits * @returns {Promise<[bigint, number][]>} */ - async _call(logits) { + async _call(logits: Tensor): Promise<[bigint, number][]> { // Sample from logits, of dims [batch, sequence_length, vocab_size]. // If index is specified, sample from [batch, index, vocab_size]. return this.sample(logits); @@ -42,7 +45,7 @@ export class LogitsSampler extends Callable { * @throws {Error} If not implemented in subclass. * @returns {Promise<[bigint, number][]>} */ - async sample(logits) { + async sample(logits: Tensor): Promise<[bigint, number][]> { throw Error("sample should be implemented in subclasses.") } @@ -52,10 +55,10 @@ export class LogitsSampler extends Callable { * @param {number} index * @returns {Float32Array} */ - getLogits(logits, index) { + getLogits(logits: Tensor, index: number): Float32Array { let vocabSize = logits.dims.at(-1); - let logs = /** @type {Float32Array} */(logits.data); + let logs = logits.data as Float32Array; if (index === -1) { logs = logs.slice(-vocabSize); @@ -71,7 +74,7 @@ export class LogitsSampler extends Callable { * @param {import("../transformers.js").DataArray} probabilities An array of probabilities to use for selection. * @returns {number} The index of the selected item. */ - randomSelect(probabilities) { + randomSelect(probabilities: DataArray): number { // Return index of chosen item let sumProbabilities = 0; for (let i = 0; i < probabilities.length; ++i) { @@ -93,7 +96,7 @@ export class LogitsSampler extends Callable { * @param {GenerationConfig} generation_config An object containing options for the sampler. * @returns {LogitsSampler} A Sampler object. */ - static getSampler(generation_config) { + static getSampler(generation_config: GenerationConfig): LogitsSampler { // - *greedy decoding*: `num_beams=1` and `do_sample=False` // - *contrastive search*: `penalty_alpha>0` and `top_k>1` // - *multinomial sampling*: `num_beams=1` and `do_sample=True` @@ -127,7 +130,7 @@ class GreedySampler extends LogitsSampler { * @param {Tensor} logits * @returns {Promise<[bigint, number][]>} An array with a single tuple, containing the index of the maximum value and a meaningless score (since this is a greedy search). */ - async sample(logits) { + async sample(logits: Tensor): Promise<[bigint, number][]> { // NOTE: no need to do log_softmax here since we only take the maximum const argmax = max(logits.data)[1]; @@ -149,7 +152,7 @@ class MultinomialSampler extends LogitsSampler { * @param {Tensor} logits * @returns {Promise<[bigint, number][]>} */ - async sample(logits) { + async sample(logits: Tensor): Promise<[bigint, number][]> { let k = logits.dims.at(-1); // defaults to vocab size if (this.generation_config.top_k > 0) { k = Math.min(this.generation_config.top_k, k); @@ -159,7 +162,7 @@ class MultinomialSampler extends LogitsSampler { const [v, i] = await topk(logits, k); // Compute softmax over logits - const probabilities = softmax(/** @type {Float32Array} */(v.data)); + const probabilities = softmax(v.data as Float32Array); return Array.from({ length: this.generation_config.num_beams }, () => { const sampledIndex = this.randomSelect(probabilities); @@ -182,7 +185,7 @@ class BeamSearchSampler extends LogitsSampler { * @param {Tensor} logits * @returns {Promise<[bigint, number][]>} */ - async sample(logits) { + async sample(logits: Tensor): Promise<[bigint, number][]> { let k = logits.dims.at(-1); // defaults to vocab size if (this.generation_config.top_k > 0) { k = Math.min(this.generation_config.top_k, k); @@ -192,7 +195,7 @@ class BeamSearchSampler extends LogitsSampler { const [v, i] = await topk(logits, k); // Compute softmax over logits - const probabilities = softmax(/** @type {Float32Array} */(v.data)); + const probabilities = softmax(v.data as Float32Array); return Array.from({ length: this.generation_config.num_beams }, (_, x) => { return [ diff --git a/src/generation/parameters.js b/src/generation/parameters.js deleted file mode 100644 index 1e2f2def3..000000000 --- a/src/generation/parameters.js +++ /dev/null @@ -1,35 +0,0 @@ - -/** - * @module generation/parameters - */ - -/** - * @typedef {Object} GenerationFunctionParameters - * @property {import('../utils/tensor.js').Tensor} [inputs=null] (`Tensor` of varying shape depending on the modality, *optional*): - * The sequence used as a prompt for the generation or as model inputs to the encoder. If `null` the - * method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - * should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of - * `input_ids`, `input_values`, `input_features`, or `pixel_values`. - * @property {import('./configuration_utils.js').GenerationConfig} [generation_config=null] (`GenerationConfig`, *optional*): - * The generation configuration to be used as base parametrization for the generation call. - * `**kwargs` passed to generate matching the attributes of `generation_config` will override them. - * If `generation_config` is not provided, the default will be used, which has the following loading - * priority: - * - (1) from the `generation_config.json` model file, if it exists; - * - (2) from the model configuration. Please note that unspecified parameters will inherit [`GenerationConfig`]'s - * default values, whose documentation should be checked to parameterize generation. - * @property {import('./logits_process.js').LogitsProcessorList} [logits_processor=null] (`LogitsProcessorList`, *optional*): - * Custom logits processors that complement the default logits processors built from arguments and - * generation config. If a logit processor is passed that is already created with the arguments or a - * generation config an error is thrown. This feature is intended for advanced users. - * @property {import('./stopping_criteria.js').StoppingCriteriaList} [stopping_criteria=null] (`StoppingCriteriaList`, *optional*): - * Custom stopping criteria that complements the default stopping criteria built from arguments and a - * generation config. If a stopping criteria is passed that is already created with the arguments or a - * generation config an error is thrown. This feature is intended for advanced users. - * @property {import('./streamers.js').BaseStreamer} [streamer=null] (`BaseStreamer`, *optional*): - * Streamer object that will be used to stream the generated sequences. Generated tokens are passed - * through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - * @property {number[]} [decoder_input_ids=null] (`number[]`, *optional*): - * If the model is an encoder-decoder model, this argument is used to pass the `decoder_input_ids`. - * @param {any} [kwargs] (`Dict[str, any]`, *optional*): - */ diff --git a/src/generation/parameters.ts b/src/generation/parameters.ts new file mode 100644 index 000000000..1d18155af --- /dev/null +++ b/src/generation/parameters.ts @@ -0,0 +1,94 @@ +/** + * @module generation/parameters + */ + +import type { Tensor } from '../utils/tensor'; +import type { GenerationConfig } from './configuration_utils'; +import type { LogitsProcessorList } from './logits_process'; +import type { StoppingCriteriaList } from './stopping_criteria'; +import type { BaseStreamer } from './streamers'; + +/** + * @typedef {Object} GenerationFunctionParameters + * @property {import('../utils/tensor').Tensor} [inputs=null] (`Tensor` of varying shape depending on the modality, *optional*): + * The sequence used as a prompt for the generation or as model inputs to the encoder. If `null` the + * method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + * should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + * `input_ids`, `input_values`, `input_features`, or `pixel_values`. + * @property {import('./configuration_utils').GenerationConfig} [generation_config=null] (`GenerationConfig`, *optional*): + * The generation configuration to be used as base parametrization for the generation call. + * `**kwargs` passed to generate matching the attributes of `generation_config` will override them. + * If `generation_config` is not provided, the default will be used, which has the following loading + * priority: + * - (1) from the `generation_config.json` model file, if it exists; + * - (2) from the model configuration. Please note that unspecified parameters will inherit [`GenerationConfig`]'s + * default values, whose documentation should be checked to parameterize generation. + * @property {import('./logits_process').LogitsProcessorList} [logits_processor=null] (`LogitsProcessorList`, *optional*): + * Custom logits processors that complement the default logits processors built from arguments and + * generation config. If a logit processor is passed that is already created with the arguments or a + * generation config an error is thrown. This feature is intended for advanced users. + * @property {import('./stopping_criteria').StoppingCriteriaList} [stopping_criteria=null] (`StoppingCriteriaList`, *optional*): + * Custom stopping criteria that complements the default stopping criteria built from arguments and a + * generation config. If a stopping criteria is passed that is already created with the arguments or a + * generation config an error is thrown. This feature is intended for advanced users. + * @property {import('./streamers').BaseStreamer} [streamer=null] (`BaseStreamer`, *optional*): + * Streamer object that will be used to stream the generated sequences. Generated tokens are passed + * through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + * @property {number[]} [decoder_input_ids=null] (`number[]`, *optional*): + * If the model is an encoder-decoder model, this argument is used to pass the `decoder_input_ids`. + * @param {any} [kwargs] (`Dict[str, any]`, *optional*): + */ + +/** + * Parameters for generation functions + */ +export interface GenerationFunctionParameters { + /** + * The sequence used as a prompt for the generation or as model inputs to the encoder. If `null` the + * method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + * should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + * `input_ids`, `input_values`, `input_features`, or `pixel_values`. + */ + inputs?: Tensor | null; + + /** + * The generation configuration to be used as base parametrization for the generation call. + * `**kwargs` passed to generate matching the attributes of `generation_config` will override them. + * If `generation_config` is not provided, the default will be used, which has the following loading + * priority: + * - (1) from the `generation_config.json` model file, if it exists; + * - (2) from the model configuration. Please note that unspecified parameters will inherit `GenerationConfig`'s + * default values, whose documentation should be checked to parameterize generation. + */ + generation_config?: GenerationConfig | null; + + /** + * Custom logits processors that complement the default logits processors built from arguments and + * generation config. If a logit processor is passed that is already created with the arguments or a + * generation config an error is thrown. This feature is intended for advanced users. + */ + logits_processor?: LogitsProcessorList | null; + + /** + * Custom stopping criteria that complements the default stopping criteria built from arguments and a + * generation config. If a stopping criteria is passed that is already created with the arguments or a + * generation config an error is thrown. This feature is intended for advanced users. + */ + stopping_criteria?: StoppingCriteriaList | null; + + /** + * Streamer object that will be used to stream the generated sequences. Generated tokens are passed + * through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + */ + streamer?: BaseStreamer | null; + + /** + * If the model is an encoder-decoder model, this argument is used to pass the `decoder_input_ids`. + */ + decoder_input_ids?: number[] | null; + + /** + * Additional keyword arguments + */ + [key: string]: any; +} diff --git a/src/generation/stopping_criteria.js b/src/generation/stopping_criteria.ts similarity index 85% rename from src/generation/stopping_criteria.js rename to src/generation/stopping_criteria.ts index 08434f2b4..5cda9c296 100644 --- a/src/generation/stopping_criteria.js +++ b/src/generation/stopping_criteria.ts @@ -21,13 +21,18 @@ export class StoppingCriteria extends Callable { * or scores for each vocabulary token after SoftMax. * @returns {boolean[]} A list of booleans indicating whether each sequence should be stopped. */ - _call(input_ids, scores) { + _call(input_ids: number[][], scores: number[][]): boolean[] { throw Error("StoppingCriteria needs to be subclassed"); } } /** */ export class StoppingCriteriaList extends Callable { + /** + * @type {StoppingCriteria[]} + */ + criteria: StoppingCriteria[]; + /** * Constructs a new instance of `StoppingCriteriaList`. */ @@ -41,7 +46,7 @@ export class StoppingCriteriaList extends Callable { * * @param {StoppingCriteria} item The stopping criterion to add. */ - push(item) { + push(item: StoppingCriteria) { this.criteria.push(item); } @@ -50,7 +55,7 @@ export class StoppingCriteriaList extends Callable { * * @param {StoppingCriteria|StoppingCriteriaList|StoppingCriteria[]} items The stopping criteria to add. */ - extend(items) { + extend(items: StoppingCriteria | StoppingCriteriaList | StoppingCriteria[]) { if (items instanceof StoppingCriteriaList) { items = items.criteria; } else if (items instanceof StoppingCriteria) { @@ -59,7 +64,7 @@ export class StoppingCriteriaList extends Callable { this.criteria.push(...items); } - _call(input_ids, scores) { + _call(input_ids: number[][], scores: number[][]): boolean[] { const is_done = new Array(input_ids.length).fill(false); for (const criterion of this.criteria) { const criterion_done = criterion(input_ids, scores); @@ -80,19 +85,21 @@ export class StoppingCriteriaList extends Callable { * Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens. */ export class MaxLengthCriteria extends StoppingCriteria { + max_length: number; + max_position_embeddings: number | null; /** * * @param {number} max_length The maximum length that the output sequence can have in number of tokens. * @param {number} [max_position_embeddings=null] The maximum model length, as defined by the model's `config.max_position_embeddings` attribute. */ - constructor(max_length, max_position_embeddings = null) { + constructor(max_length: number, max_position_embeddings: number = null) { super(); this.max_length = max_length; this.max_position_embeddings = max_position_embeddings; } - _call(input_ids) { + _call(input_ids: number[][]): boolean[] { return input_ids.map(ids => ids.length >= this.max_length); } } @@ -104,13 +111,13 @@ export class MaxLengthCriteria extends StoppingCriteria { * By default, it uses the `model.generation_config.eos_token_id`. */ export class EosTokenCriteria extends StoppingCriteria { - + eos_token_id: number[]; /** * * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. * Optionally, use a list to set multiple *end-of-sequence* tokens. */ - constructor(eos_token_id) { + constructor(eos_token_id: number | number[]) { super(); if (!Array.isArray(eos_token_id)) { eos_token_id = [eos_token_id]; @@ -124,7 +131,7 @@ export class EosTokenCriteria extends StoppingCriteria { * @param {number[][]} scores * @returns {boolean[]} */ - _call(input_ids, scores) { + _call(input_ids: number[][], scores: number[][]): boolean[] { return input_ids.map(ids => { const last = ids.at(-1); // NOTE: We use == instead of === to allow for number/bigint comparison @@ -137,6 +144,8 @@ export class EosTokenCriteria extends StoppingCriteria { * This class can be used to stop generation whenever the user interrupts the process. */ export class InterruptableStoppingCriteria extends StoppingCriteria { + interrupted: boolean; + constructor() { super(); this.interrupted = false; @@ -150,7 +159,7 @@ export class InterruptableStoppingCriteria extends StoppingCriteria { this.interrupted = false; } - _call(input_ids, scores) { + _call(input_ids: number[][], scores: number[][]): boolean[] { return new Array(input_ids.length).fill(this.interrupted); } } diff --git a/src/generation/streamers.js b/src/generation/streamers.ts similarity index 81% rename from src/generation/streamers.js rename to src/generation/streamers.ts index 33c882081..7d360788d 100644 --- a/src/generation/streamers.js +++ b/src/generation/streamers.ts @@ -3,16 +3,16 @@ * @module generation/streamers */ -import { mergeArrays } from '../utils/core.js'; -import { is_chinese_char } from '../tokenizers.js'; -import { apis } from '../env.js'; +import { mergeArrays } from '../utils/core'; +import { is_chinese_char, PreTrainedTokenizer } from '../tokenizers'; +import { apis } from '../env'; export class BaseStreamer { /** * Function that is called by `.generate()` to push new tokens * @param {bigint[][]} value */ - put(value) { + put(value: bigint[][]) { throw Error('Not implemented'); } @@ -25,16 +25,25 @@ export class BaseStreamer { } const stdout_write = apis.IS_PROCESS_AVAILABLE - ? x => process.stdout.write(x) - : x => console.log(x); + ? (x: string | Uint8Array) => process.stdout.write(x) + : (x: any) => console.log(x); /** * Simple text streamer that prints the token(s) to stdout as soon as entire words are formed. */ export class TextStreamer extends BaseStreamer { + tokenizer: PreTrainedTokenizer; + skip_prompt: boolean; + callback_function: (arg0: string) => void; + token_callback_function: (arg0: bigint[]) => void; + decode_kwargs: object; + token_cache: bigint[]; + print_len: number; + next_tokens_are_prompt: boolean; + /** * - * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer + * @param {import('../tokenizers').PreTrainedTokenizer} tokenizer * @param {Object} options * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding @@ -42,14 +51,14 @@ export class TextStreamer extends BaseStreamer { * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method */ - constructor(tokenizer, { + constructor(tokenizer: PreTrainedTokenizer, { skip_prompt = false, callback_function = null, token_callback_function = null, skip_special_tokens = true, decode_kwargs = {}, ...kwargs - } = {}) { + }: { skip_prompt?: boolean; skip_special_tokens?: boolean; callback_function?: (arg0: string) => void; token_callback_function?: (arg0: bigint[]) => void; decode_kwargs?: object; } = {}) { super(); this.tokenizer = tokenizer; this.skip_prompt = skip_prompt; @@ -67,7 +76,7 @@ export class TextStreamer extends BaseStreamer { * Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. * @param {bigint[][]} value */ - put(value) { + put(value: bigint[][]) { if (value.length > 1) { throw Error('TextStreamer only supports batch size of 1'); } @@ -84,7 +93,7 @@ export class TextStreamer extends BaseStreamer { this.token_cache = mergeArrays(this.token_cache, tokens); const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs); - let printable_text; + let printable_text: string | string[]; if (text.endsWith('\n')) { // After the symbol for a new line, we flush the cache. printable_text = text.slice(this.print_len); @@ -108,7 +117,7 @@ export class TextStreamer extends BaseStreamer { * Flushes any remaining cache and prints a newline to stdout. */ end() { - let printable_text; + let printable_text: string; if (this.token_cache.length > 0) { const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs); printable_text = text.slice(this.print_len); @@ -126,7 +135,7 @@ export class TextStreamer extends BaseStreamer { * @param {string} text * @param {boolean} stream_end */ - on_finalized_text(text, stream_end) { + on_finalized_text(text: string, stream_end: boolean) { if (text.length > 0) { this.callback_function?.(text); } @@ -145,8 +154,15 @@ export class TextStreamer extends BaseStreamer { * - The stream is finalized (on_finalize) */ export class WhisperTextStreamer extends TextStreamer { + timestamp_begin: number; + on_chunk_start: (arg0: number) => void; + on_chunk_end: (arg0: number) => void; + on_finalize: () => void; + time_precision: number; + waiting_for_timestamp: boolean; + /** - * @param {import('../tokenizers.js').WhisperTokenizer} tokenizer + * @param {import('../tokenizers').WhisperTokenizer} tokenizer * @param {Object} options * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display @@ -158,7 +174,7 @@ export class WhisperTextStreamer extends TextStreamer { * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method */ - constructor(tokenizer, { + constructor(tokenizer: import('../tokenizers').WhisperTokenizer, { skip_prompt = false, callback_function = null, token_callback_function = null, @@ -168,7 +184,7 @@ export class WhisperTextStreamer extends TextStreamer { time_precision = 0.02, skip_special_tokens = true, decode_kwargs = {}, - } = {}) { + }: { skip_prompt?: boolean; callback_function?: (arg0: string) => void; token_callback_function?: (arg0: bigint[]) => void; on_chunk_start?: (arg0: number) => void; on_chunk_end?: (arg0: number) => void; on_finalize?: () => void; time_precision?: number; skip_special_tokens?: boolean; decode_kwargs?: object; } = {}) { super(tokenizer, { skip_prompt, skip_special_tokens, @@ -190,7 +206,7 @@ export class WhisperTextStreamer extends TextStreamer { /** * @param {bigint[][]} value */ - put(value) { + put(value: bigint[][]) { if (value.length > 1) { throw Error('WhisperTextStreamer only supports batch size of 1'); } diff --git a/src/models/beit/image_processing_beit.js b/src/models/beit/image_processing_beit.ts similarity index 100% rename from src/models/beit/image_processing_beit.js rename to src/models/beit/image_processing_beit.ts diff --git a/src/ops/registry.js b/src/ops/registry.ts similarity index 83% rename from src/ops/registry.js rename to src/ops/registry.ts index 4f2179bec..d44a008ed 100644 --- a/src/ops/registry.js +++ b/src/ops/registry.ts @@ -1,7 +1,7 @@ import { createInferenceSession, isONNXProxy } from "../backends/onnx.js"; import { Tensor } from "../utils/tensor.js"; import { apis } from "../env.js"; - +import type { InferenceSession } from "onnxruntime-common"; const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV; /** * Asynchronously creates a wrapper function for running an ONNX inference session. @@ -14,15 +14,16 @@ const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV; * @returns {Promise): Promise>} * The wrapper function for running the ONNX inference session. */ -const wrap = async (session_bytes, session_options, names) => { +const wrap = async (session_bytes: number[], session_options: InferenceSession.SessionOptions, names: T): Promise<((arg0: Record) => Promise)> => { const session = await createInferenceSession( - new Uint8Array(session_bytes), session_options, + new Uint8Array(session_bytes), session_options, {} ); - /** @type {Promise} */ - let chain = Promise.resolve(); + let chain: Promise = Promise.resolve(); - return /** @type {any} */(async (/** @type {Record} */ inputs) => { + return (async (inputs: Record) => { const proxied = isONNXProxy(); const ortFeed = Object.fromEntries(Object.entries(inputs).map(([k, v]) => [k, (proxied ? v.clone() : v).ort_tensor])); @@ -32,18 +33,28 @@ const wrap = async (session_bytes, session_options, names) => { if (Array.isArray(names)) { return names.map((n) => new Tensor(outputs[n])); } else { - return new Tensor(outputs[/** @type {string} */(names)]); + return new Tensor(outputs[names]); } - }) + }) as any; } // In-memory registry of initialized ONNX operators export class TensorOpRegistry { + private static _nearest_interpolate_4d: Promise<(arg0: Record) => Promise> | undefined; + private static _bilinear_interpolate_4d: Promise<(arg0: Record) => Promise> | undefined; + private static _bicubic_interpolate_4d: Promise<(arg0: Record) => Promise> | undefined; + private static _matmul: Promise<(arg0: Record) => Promise> | undefined; + private static _stft: Promise<(arg0: Record) => Promise> | undefined; + private static _rfft: Promise<(arg0: Record) => Promise> | undefined; + private static _slice: Promise<(arg0: Record) => Promise> | undefined; + private static _top_k: Promise<(arg0: Record) => Promise> | undefined; + + static session_options = { // TODO: Allow for multiple execution providers // executionProviders: ['webgpu'], }; - + static get nearest_interpolate_4d() { if (!this._nearest_interpolate_4d) { this._nearest_interpolate_4d = wrap( @@ -114,8 +125,8 @@ export class TensorOpRegistry { this._top_k = wrap( [8, 10, 18, 0, 58, 73, 10, 18, 10, 1, 120, 10, 1, 107, 18, 1, 118, 18, 1, 105, 34, 4, 84, 111, 112, 75, 18, 1, 116, 90, 9, 10, 1, 120, 18, 4, 10, 2, 8, 1, 90, 15, 10, 1, 107, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 1, 98, 9, 10, 1, 118, 18, 4, 10, 2, 8, 1, 98, 9, 10, 1, 105, 18, 4, 10, 2, 8, 7, 66, 2, 16, 21], this.session_options, - [ /* Values */ 'v', /* Indices */ 'i'] - ) + ['v', 'i'] + ); } return this._top_k; } diff --git a/src/utils/audio.js b/src/utils/audio.ts similarity index 77% rename from src/utils/audio.js rename to src/utils/audio.ts index 54dc87008..1a078f533 100644 --- a/src/utils/audio.js +++ b/src/utils/audio.ts @@ -9,15 +9,42 @@ import { getFile, -} from './hub.js'; -import { FFT, max } from './maths.js'; +} from './hub'; +import { FFT, max } from './maths'; import { calculateReflectOffset, saveBlob, -} from './core.js'; -import { apis } from '../env.js'; +} from './core'; +import { apis } from '../env'; import fs from 'fs'; -import { Tensor, matmul } from './tensor.js'; +import { Tensor, matmul } from './tensor'; +import type { AnyTypedArray } from './maths'; +import type { DataType } from './tensor'; + +export interface SpectrogramOptions { + fft_length?: number | null; + power?: number | null; + center?: boolean; + pad_mode?: 'reflect' | 'constant' | 'edge'; + onesided?: boolean; + preemphasis?: number | null; + mel_filters?: number[][] | null; + mel_floor?: number; + log_mel?: 'log' | 'log10' | 'dB' | null; + reference?: number; + min_value?: number; + db_range?: number | null; + remove_dc_offset?: boolean | null; + min_num_frames?: number | null; + max_num_frames?: number | null; + do_pad?: boolean; + transpose?: boolean; +} +export interface WindowFunctionOptions { + periodic?: boolean; + frame_length?: number | null; + center?: boolean; +} /** * Helper function to read audio from a path/URL. @@ -25,7 +52,7 @@ import { Tensor, matmul } from './tensor.js'; * @param {number} sampling_rate The sampling rate to use when decoding the audio. * @returns {Promise} The decoded audio as a `Float32Array`. */ -export async function read_audio(url, sampling_rate) { +export async function read_audio(url: string | URL, sampling_rate?: number): Promise { if (typeof AudioContext === 'undefined') { // Running in node or an environment without AudioContext throw Error( @@ -35,15 +62,15 @@ export async function read_audio(url, sampling_rate) { ) } - const response = await (await getFile(url)).arrayBuffer(); + const file = await getFile(url); + const response = await file.arrayBuffer(); const audioCTX = new AudioContext({ sampleRate: sampling_rate }); if (typeof sampling_rate === 'undefined') { console.warn(`No sampling rate provided, using default of ${audioCTX.sampleRate}Hz.`) } - const decoded = await audioCTX.decodeAudioData(response); + const decoded = await audioCTX.decodeAudioData(response as ArrayBuffer); - /** @type {Float32Array} */ - let audio; + let audio: Float32Array; // We now replicate HuggingFace's `ffmpeg_read` method: if (decoded.numberOfChannels === 2) { @@ -87,7 +114,7 @@ export async function read_audio(url, sampling_rate) { * @param {number} a_0 Offset for the generalized cosine window. * @returns {Float64Array} The generated window. */ -function generalized_cosine_window(M, a_0) { +function generalized_cosine_window(M: number, a_0: number): Float64Array { if (M < 1) { return new Float64Array(); } @@ -112,7 +139,7 @@ function generalized_cosine_window(M, a_0) { * @param {number} M The length of the Hanning window to generate. * @returns {Float64Array} The generated Hanning window. */ -export function hanning(M) { +export function hanning(M: number): Float64Array { return generalized_cosine_window(M, 0.5); } @@ -124,15 +151,15 @@ export function hanning(M) { * @param {number} M The length of the Hamming window to generate. * @returns {Float64Array} The generated Hamming window. */ -export function hamming(M) { +export function hamming(M: number): Float64Array { return generalized_cosine_window(M, 0.54); } -const HERTZ_TO_MEL_MAPPING = { - "htk": (/** @type {number} */ freq) => 2595.0 * Math.log10(1.0 + (freq / 700.0)), - "kaldi": (/** @type {number} */ freq) => 1127.0 * Math.log(1.0 + (freq / 700.0)), - "slaney": (/** @type {number} */ freq, min_log_hertz = 1000.0, min_log_mel = 15.0, logstep = 27.0 / Math.log(6.4)) => +const HERTZ_TO_MEL_MAPPING: Record number> = { + "htk": (freq: number) => 2595.0 * Math.log10(1.0 + (freq / 700.0)), + "kaldi": (freq: number) => 1127.0 * Math.log(1.0 + (freq / 700.0)), + "slaney": (freq: number, min_log_hertz:number = 1000.0, min_log_mel:number = 15.0, logstep:number = 27.0 / Math.log(6.4)) => freq >= min_log_hertz ? min_log_mel + Math.log(freq / min_log_hertz) * logstep : 3.0 * freq / 200.0, @@ -144,19 +171,22 @@ const HERTZ_TO_MEL_MAPPING = { * @param {string} [mel_scale] * @returns {T} */ -function hertz_to_mel(freq, mel_scale = "htk") { +function hertz_to_mel(freq: T, mel_scale: keyof typeof HERTZ_TO_MEL_MAPPING = "htk"): T { const fn = HERTZ_TO_MEL_MAPPING[mel_scale]; if (!fn) { throw new Error('mel_scale should be one of "htk", "slaney" or "kaldi".'); } - return typeof freq === 'number' ? fn(freq) : freq.map(x => fn(x)); + if (typeof freq === 'number') { + return fn(freq) as T; + } + return freq.map(x => fn(x)) as T; } -const MEL_TO_HERTZ_MAPPING = { - "htk": (/** @type {number} */ mels) => 700.0 * (10.0 ** (mels / 2595.0) - 1.0), - "kaldi": (/** @type {number} */ mels) => 700.0 * (Math.exp(mels / 1127.0) - 1.0), - "slaney": (/** @type {number} */ mels, min_log_hertz = 1000.0, min_log_mel = 15.0, logstep = Math.log(6.4) / 27.0) => mels >= min_log_mel +const MEL_TO_HERTZ_MAPPING: Record number> = { + "htk": (mels: number) => 700.0 * (10.0 ** (mels / 2595.0) - 1.0), + "kaldi": (mels: number) => 700.0 * (Math.exp(mels / 1127.0) - 1.0), + "slaney": (mels: number, min_log_hertz = 1000.0, min_log_mel = 15.0, logstep = Math.log(6.4) / 27.0) => mels >= min_log_mel ? min_log_hertz * Math.exp(logstep * (mels - min_log_mel)) : 200.0 * mels / 3.0, } @@ -164,16 +194,19 @@ const MEL_TO_HERTZ_MAPPING = { /** * @template {Float32Array|Float64Array|number} T * @param {T} mels - * @param {string} [mel_scale] + * @param {keyof typeof MEL_TO_HERTZ_MAPPING} [mel_scale] * @returns {T} */ -function mel_to_hertz(mels, mel_scale = "htk") { +function mel_to_hertz(mels: T, mel_scale: keyof typeof MEL_TO_HERTZ_MAPPING = "htk"): T { const fn = MEL_TO_HERTZ_MAPPING[mel_scale]; if (!fn) { throw new Error('mel_scale should be one of "htk", "slaney" or "kaldi".'); } - return typeof mels === 'number' ? fn(mels) : mels.map(x => fn(x)); + if (typeof mels === 'number') { + return fn(mels) as T; + } + return mels.map(x => fn(x)) as T; } /** @@ -185,7 +218,7 @@ function mel_to_hertz(mels, mel_scale = "htk") { * @param {Float64Array} filter_freqs Center frequencies of the triangular filters to create, in Hz, of shape `(num_mel_filters,)`. * @returns {number[][]} of shape `(num_frequency_bins, num_mel_filters)`. */ -function _create_triangular_filter_bank(fft_freqs, filter_freqs) { +function _create_triangular_filter_bank(fft_freqs: Float64Array, filter_freqs: Float64Array): number[][] { const filter_diff = Float64Array.from( { length: filter_freqs.length - 1 }, (_, i) => filter_freqs[i + 1] - filter_freqs[i] @@ -223,7 +256,7 @@ function _create_triangular_filter_bank(fft_freqs, filter_freqs) { * @param {number} num Number of samples to generate. * @returns `num` evenly spaced samples, calculated over the interval `[start, stop]`. */ -function linspace(start, end, num) { +function linspace(start: number, end: number, num: number): Float64Array { const step = (end - start) / (num - 1); return Float64Array.from({ length: num }, (_, i) => start + step * i); } @@ -246,15 +279,15 @@ function linspace(start, end, num) { * This is a projection matrix to go from a spectrogram to a mel spectrogram. */ export function mel_filter_bank( - num_frequency_bins, - num_mel_filters, - min_frequency, - max_frequency, - sampling_rate, - norm = null, - mel_scale = "htk", - triangularize_in_mel_space = false, -) { + num_frequency_bins: number, + num_mel_filters: number, + min_frequency: number, + max_frequency: number, + sampling_rate: number, + norm: string = null, + mel_scale: string = "htk", + triangularize_in_mel_space: boolean = false, +): number[][] { if (norm !== null && norm !== "slaney") { throw new Error('norm must be one of null or "slaney"'); } @@ -302,7 +335,7 @@ export function mel_filter_bank( * @param {number} right The amount of padding to add to the right. * @returns {T} The padded array. */ -function padReflect(array, left, right) { +function padReflect(array: T, left: number, right: number): T { // @ts-ignore const padded = new array.constructor(array.length + left + right); const w = array.length - 1; @@ -332,7 +365,7 @@ function padReflect(array, left, right) { * @param {number} db_range * @returns {T} */ -function _db_conversion_helper(spectrogram, factor, reference, min_value, db_range) { +function _db_conversion_helper(spectrogram: T, factor: number, reference: number, min_value: number, db_range: number): T { if (reference <= 0) { throw new Error('reference must be greater than zero'); } @@ -380,7 +413,7 @@ function _db_conversion_helper(spectrogram, factor, reference, min_value, db_ran * difference between the peak value and the smallest value will never be more than 80 dB. Must be greater than zero. * @returns {T} The modified spectrogram in decibels. */ -function amplitude_to_db(spectrogram, reference = 1.0, min_value = 1e-5, db_range = null) { +function amplitude_to_db(spectrogram: T, reference: number = 1.0, min_value: number = 1e-5, db_range: number | null = null): T { return _db_conversion_helper(spectrogram, 20.0, reference, min_value, db_range); } @@ -405,7 +438,7 @@ function amplitude_to_db(spectrogram, reference = 1.0, min_value = 1e-5, db_rang * difference between the peak value and the smallest value will never be more than 80 dB. Must be greater than zero. * @returns {T} The modified spectrogram in decibels. */ -function power_to_db(spectrogram, reference = 1.0, min_value = 1e-10, db_range = null) { +function power_to_db(spectrogram: T, reference: number = 1.0, min_value: number = 1e-10, db_range: number | null = null): T { return _db_conversion_helper(spectrogram, 10.0, reference, min_value, db_range); } @@ -429,45 +462,19 @@ function power_to_db(spectrogram, reference = 1.0, min_value = 1e-10, db_range = * shorter than `frame_length`, but we're assuming the array has already been zero-padded. * @param {number} frame_length The length of the analysis frames in samples (a.k.a., `fft_length`). * @param {number} hop_length The stride between successive analysis frames in samples. - * @param {Object} options - * @param {number} [options.fft_length=null] The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have. - * For optimal speed, this should be a power of two. If `null`, uses `frame_length`. - * @param {number} [options.power=1.0] If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `null`, returns complex numbers. - * @param {boolean} [options.center=true] Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `false`, frame - * `t` will start at time `t * hop_length`. - * @param {string} [options.pad_mode="reflect"] Padding mode used when `center` is `true`. Possible values are: `"constant"` (pad with zeros), - * `"edge"` (pad with edge values), `"reflect"` (pads with mirrored values). - * @param {boolean} [options.onesided=true] If `true`, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1` - * frequency bins. If `false`, also computes the negative frequencies and returns `fft_length` frequency bins. - * @param {number} [options.preemphasis=null] Coefficient for a low-pass filter that applies pre-emphasis before the DFT. - * @param {number[][]} [options.mel_filters=null] The mel filter bank of shape `(num_freq_bins, num_mel_filters)`. - * If supplied, applies this filter bank to create a mel spectrogram. - * @param {number} [options.mel_floor=1e-10] Minimum value of mel frequency banks. - * @param {string} [options.log_mel=null] How to convert the spectrogram to log scale. Possible options are: - * `null` (don't convert), `"log"` (take the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). - * Can only be used when `power` is not `null`. - * @param {number} [options.reference=1.0] Sets the input spectrogram value that corresponds to 0 dB. For example, use `max(spectrogram)[0]` to set - * the loudest part to 0 dB. Must be greater than zero. - * @param {number} [options.min_value=1e-10] The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking `log(0)`. - * For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an amplitude spectrogram, the value `1e-5` corresponds to -100 dB. - * Must be greater than zero. - * @param {number} [options.db_range=null] Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the - * peak value and the smallest value will never be more than 80 dB. Must be greater than zero. - * @param {boolean} [options.remove_dc_offset=null] Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in - * order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters. - * @param {number} [options.max_num_frames=null] If provided, limits the number of frames to compute to this value. - * @param {number} [options.min_num_frames=null] If provided, ensures the number of frames to compute is at least this value. - * @param {boolean} [options.do_pad=true] If `true`, pads the output spectrogram to have `max_num_frames` frames. - * @param {boolean} [options.transpose=false] If `true`, the returned spectrogram will have shape `(num_frames, num_frequency_bins/num_mel_filters)`. If `false`, the returned spectrogram will have shape `(num_frequency_bins/num_mel_filters, num_frames)`. + * @param {SpectrogramOptions} options * @returns {Promise} Spectrogram of shape `(num_frequency_bins, length)` (regular spectrogram) or shape `(num_mel_filters, length)` (mel spectrogram). */ export async function spectrogram( - waveform, - window, - frame_length, - hop_length, - { - fft_length = null, + waveform: Float32Array | Float64Array, + window: Float32Array | Float64Array, + frame_length: number, + hop_length: number, + options: SpectrogramOptions = {} +): Promise { + const window_length = window.length; + const { + fft_length: fft_length_opt = null, power = 1.0, center = true, pad_mode = "reflect", @@ -480,15 +487,13 @@ export async function spectrogram( min_value = 1e-10, db_range = null, remove_dc_offset = null, - - // Custom parameters for efficiency reasons min_num_frames = null, max_num_frames = null, do_pad = true, transpose = false, - } = {} -) { - const window_length = window.length; + } = options; + + let fft_length = fft_length_opt; if (fft_length === null) { fft_length = frame_length; } @@ -607,15 +612,22 @@ export async function spectrogram( // TODO: What if `mel_filters` is null? const num_mel_filters = mel_filters.length; - // Perform matrix muliplication: - // mel_spec = mel_filters @ magnitudes.T - // - mel_filters.shape=(80, 201) - // - magnitudes.shape=(3000, 201) => magnitudes.T.shape=(201, 3000) - // - mel_spec.shape=(80, 3000) + // For the Tensor creation, convert arrays to Float32Array + if (!mel_filters) { + throw new Error('mel_filters must be provided'); + } + + // Create tensors with proper typed arrays + const mel_filters_flat = mel_filters.flat(); + const mel_filters_array = new Float32Array(mel_filters_flat.length); + mel_filters_array.set(mel_filters_flat); + + const transposed_magnitude_array = new Float32Array(transposedMagnitudeData.length); + transposed_magnitude_array.set(transposedMagnitudeData); + let mel_spec = await matmul( - // TODO: Make `mel_filters` a Tensor during initialization - new Tensor('float32', mel_filters.flat(), [num_mel_filters, num_frequency_bins]), - new Tensor('float32', transposedMagnitudeData, [num_frequency_bins, d1Max]), + new Tensor('float32', mel_filters_array, [num_mel_filters, num_frequency_bins]), + new Tensor('float32', transposed_magnitude_array, [num_frequency_bins, d1Max]), ); if (transpose) { mel_spec = mel_spec.transpose(1, 0); @@ -642,9 +654,9 @@ export async function spectrogram( break; case 'dB': if (power === 1.0) { - amplitude_to_db(mel_spec_data, reference, min_value, db_range); + amplitude_to_db(mel_spec_data as Float32Array, reference, min_value, db_range); } else if (power === 2.0) { - power_to_db(mel_spec_data, reference, min_value, db_range); + power_to_db(mel_spec_data as Float32Array, reference, min_value, db_range); } else { throw new Error(`Cannot use log_mel option '${log_mel}' with power ${power}`) } @@ -661,18 +673,15 @@ export async function spectrogram( * Returns an array containing the specified window. * @param {number} window_length The length of the window in samples. * @param {string} name The name of the window function. - * @param {Object} options Additional options. - * @param {boolean} [options.periodic=true] Whether the window is periodic or symmetric. - * @param {number} [options.frame_length=null] The length of the analysis frames in samples. - * Provide a value for `frame_length` if the window is smaller than the frame length, so that it will be zero-padded. - * @param {boolean} [options.center=true] Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided. + * @param {WindowFunctionOptions} options Additional options. * @returns {Float64Array} The window of shape `(window_length,)` or `(frame_length,)`. */ -export function window_function(window_length, name, { - periodic = true, - frame_length = null, - center = true, -} = {}) { +export function window_function(window_length: number, name: string, options: WindowFunctionOptions = {}): Float64Array { + const { + periodic = true, + frame_length = null, + center = true, + } = options; const length = periodic ? window_length + 1 : window_length; let window; switch (name) { @@ -714,7 +723,7 @@ export function window_function(window_length, name, { * @param {number} rate The sample rate. * @returns {ArrayBuffer} The WAV audio buffer. */ -function encodeWAV(samples, rate) { +function encodeWAV(samples: Float32Array, rate: number): ArrayBuffer { let offset = 44; const buffer = new ArrayBuffer(offset + samples.length * 4); const view = new DataView(buffer); @@ -753,7 +762,7 @@ function encodeWAV(samples, rate) { return buffer; } -function writeString(view, offset, string) { +function writeString(view: DataView, offset: number, string: string): void { for (let i = 0; i < string.length; ++i) { view.setUint8(offset + i, string.charCodeAt(i)); } @@ -761,30 +770,32 @@ function writeString(view, offset, string) { export class RawAudio { + private audio: Float32Array; + private sampling_rate: number; /** * Create a new `RawAudio` object. * @param {Float32Array} audio Audio data * @param {number} sampling_rate Sampling rate of the audio data */ - constructor(audio, sampling_rate) { - this.audio = audio - this.sampling_rate = sampling_rate + constructor(audio: Float32Array, sampling_rate: number) { + this.audio = audio; + this.sampling_rate = sampling_rate; } /** * Convert the audio to a wav file buffer. * @returns {ArrayBuffer} The WAV file. */ - toWav() { - return encodeWAV(this.audio, this.sampling_rate) + toWav(): ArrayBuffer { + return encodeWAV(this.audio, this.sampling_rate); } /** * Convert the audio to a blob. * @returns {Blob} */ - toBlob() { + toBlob(): Blob { const wav = this.toWav(); const blob = new Blob([wav], { type: 'audio/wav' }); return blob; @@ -794,23 +805,22 @@ export class RawAudio { * Save the audio to a wav file. * @param {string} path */ - async save(path) { - let fn; - + async save(path: string): Promise { if (apis.IS_BROWSER_ENV) { if (apis.IS_WEBWORKER_ENV) { - throw new Error('Unable to save a file from a Web Worker.') + throw new Error('Unable to save a file from a Web Worker.'); } - fn = saveBlob; + // Since saveBlob is synchronous, wrap it in a Promise + await new Promise((resolve) => { + saveBlob(path, this.toBlob()); + resolve(); + }); } else if (apis.IS_FS_AVAILABLE) { - fn = async (/** @type {string} */ path, /** @type {Blob} */ blob) => { - let buffer = await blob.arrayBuffer(); - fs.writeFileSync(path, Buffer.from(buffer)); - } + const blob = this.toBlob(); + const buffer = await blob.arrayBuffer(); + fs.writeFileSync(path, Buffer.from(buffer)); } else { - throw new Error('Unable to save because filesystem is disabled in this environment.') + throw new Error('Unable to save because filesystem is disabled in this environment.'); } - - await fn(path, this.toBlob()) } } diff --git a/src/utils/constants.js b/src/utils/constants.ts similarity index 100% rename from src/utils/constants.js rename to src/utils/constants.ts diff --git a/src/utils/core.js b/src/utils/core.ts similarity index 79% rename from src/utils/core.js rename to src/utils/core.ts index a74ee123c..d06600a3b 100644 --- a/src/utils/core.js +++ b/src/utils/core.ts @@ -1,4 +1,3 @@ - /** * @file Core utility functions/classes for Transformers.js. * @@ -14,6 +13,11 @@ * @property {string} name The model id or directory path. * @property {string} file The name of the file. */ +export interface InitiateProgressInfo { + status: 'initiate'; + name: string; // The model id or directory path + file: string; // The name of the file +} /** * @typedef {Object} DownloadProgressInfo @@ -21,6 +25,11 @@ * @property {string} name The model id or directory path. * @property {string} file The name of the file. */ +export interface DownloadProgressInfo { + status: 'download'; + name: string; + file: string; +} /** * @typedef {Object} ProgressStatusInfo @@ -31,6 +40,14 @@ * @property {number} loaded The number of bytes loaded. * @property {number} total The total number of bytes to be loaded. */ +export interface ProgressStatusInfo { + status: 'progress'; + name: string; + file: string; + progress: number; + loaded: number; + total: number; +} /** * @typedef {Object} DoneProgressInfo @@ -38,6 +55,11 @@ * @property {string} name The model id or directory path. * @property {string} file The name of the file. */ +export interface DoneProgressInfo { + status: 'done'; + name: string; + file: string; +} /** * @typedef {Object} ReadyProgressInfo @@ -45,10 +67,16 @@ * @property {string} task The loaded task. * @property {string} model The loaded model. */ +export interface ReadyProgressInfo { + status: 'ready'; + task: string; + model: string; +} /** * @typedef {InitiateProgressInfo | DownloadProgressInfo | ProgressStatusInfo | DoneProgressInfo | ReadyProgressInfo} ProgressInfo */ +export type ProgressInfo = InitiateProgressInfo | DownloadProgressInfo | ProgressStatusInfo | DoneProgressInfo | ReadyProgressInfo; /** * A callback function that is called with progress information. @@ -56,6 +84,7 @@ * @param {ProgressInfo} progressInfo * @returns {void} */ +export type ProgressCallback = (progressInfo: ProgressInfo) => void; /** * Helper function to dispatch progress callbacks. @@ -65,7 +94,7 @@ * @returns {void} * @private */ -export function dispatchCallback(progress_callback, data) { +export function dispatchCallback(progress_callback: ProgressCallback | null | undefined, data: ProgressInfo): void { if (progress_callback) progress_callback(data); } @@ -76,7 +105,7 @@ export function dispatchCallback(progress_callback, data) { * @returns {Object} The reversed object. * @see https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript */ -export function reverseDictionary(data) { +export function reverseDictionary(data: Record): Record { // https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript return Object.fromEntries(Object.entries(data).map(([key, value]) => [value, key])); } @@ -87,7 +116,7 @@ export function reverseDictionary(data) { * @param {string} string The string to escape. * @returns {string} The escaped string. */ -export function escapeRegExp(string) { +export function escapeRegExp(string: string): string { return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string } @@ -98,7 +127,7 @@ export function escapeRegExp(string) { * * Adapted from https://stackoverflow.com/a/71091338/13989043 */ -export function isTypedArray(val) { +export function isTypedArray(val: any): boolean { return val?.prototype?.__proto__?.constructor?.name === 'TypedArray'; } @@ -108,7 +137,7 @@ export function isTypedArray(val) { * @param {*} x The value to check. * @returns {boolean} True if the value is a string, false otherwise. */ -export function isIntegralNumber(x) { +export function isIntegralNumber(x: any): boolean { return Number.isInteger(x) || typeof x === 'bigint' } @@ -117,7 +146,7 @@ export function isIntegralNumber(x) { * @param {*} x The value to check. * @returns {boolean} True if the value is `null`, `undefined` or `-1`, false otherwise. */ -export function isNullishDimension(x) { +export function isNullishDimension(x: any): boolean { return x === null || x === undefined || x === -1; } @@ -127,7 +156,7 @@ export function isNullishDimension(x) { * @param {any[]} arr The nested array to calculate dimensions for. * @returns {number[]} An array containing the dimensions of the input array. */ -export function calculateDimensions(arr) { +export function calculateDimensions(arr: any[]): number[] { const dimensions = []; let current = arr; while (Array.isArray(current)) { @@ -145,7 +174,7 @@ export function calculateDimensions(arr) { * @returns {*} The value of the popped key. * @throws {Error} If the key does not exist and no default value is provided. */ -export function pop(obj, key, defaultValue = undefined) { +export function pop(obj: Record, key: string, defaultValue?: T): T { const value = obj[key]; if (value !== undefined) { delete obj[key]; @@ -163,7 +192,7 @@ export function pop(obj, key, defaultValue = undefined) { * @param {Array[]} arrs Arrays to merge. * @returns {Array} The merged array. */ -export function mergeArrays(...arrs) { +export function mergeArrays(...arrs: T[][]): T[] { return Array.prototype.concat.apply([], arrs); } @@ -173,7 +202,7 @@ export function mergeArrays(...arrs) { * @returns {Array} Returns the computed Cartesian product as an array * @private */ -export function product(...a) { +export function product(...a: any[]): [] { // Cartesian product of items // Adapted from https://stackoverflow.com/a/43053803 return a.reduce((a, b) => a.flatMap(d => b.map(e => [d, e]))); @@ -185,7 +214,7 @@ export function product(...a) { * @param {number} w The window size. * @returns {number} The index offset. */ -export function calculateReflectOffset(i, w) { +export function calculateReflectOffset(i: number, w: number): number { return Math.abs((i + w) % (2 * w) - w); } @@ -194,7 +223,7 @@ export function calculateReflectOffset(i, w) { * @param {string} path The path to save the blob to * @param {Blob} blob The blob to save */ -export function saveBlob(path, blob){ +export function saveBlob(path: string, blob: Blob): void { // Convert the canvas content to a data URL const dataURL = URL.createObjectURL(blob); @@ -221,7 +250,7 @@ export function saveBlob(path, blob){ * @param {string[]} props * @returns {Object} */ -export function pick(o, props) { +export function pick, K extends keyof T>(o: T, props: K[]): Pick { return Object.assign( {}, ...props.map((prop) => { @@ -238,7 +267,7 @@ export function pick(o, props) { * @param {string} s The string to calculate the length of. * @returns {number} The length of the string. */ -export function len(s) { +export function len(s: string): number { let length = 0; for (const c of s) ++length; return length; @@ -250,7 +279,7 @@ export function len(s) { * @param {any[]|string} arr The array or string to search. * @param {any} value The value to count. */ -export function count(arr, value) { +export function count(arr: T[] | string, value: T): number { let count = 0; for (const v of arr) { if (v === value) ++count; diff --git a/src/utils/data-structures.js b/src/utils/data-structures.ts similarity index 85% rename from src/utils/data-structures.js rename to src/utils/data-structures.ts index 2340d12c0..3631f45db 100644 --- a/src/utils/data-structures.js +++ b/src/utils/data-structures.ts @@ -1,4 +1,3 @@ - /** * @file Custom data structures. * @@ -8,6 +7,10 @@ * @module utils/data-structures */ +/** + * Type for the comparator function used in PriorityQueue + */ +export type ComparatorFn = (a: T, b: T) => boolean; /** * Efficient Heap-based Implementation of a Priority Queue. @@ -18,13 +21,16 @@ * - https://stackoverflow.com/a/42919752/13989043 (original) * - https://github.com/belladoreai/llama-tokenizer-js (minor improvements) */ -export class PriorityQueue { +export class PriorityQueue { + private _heap: T[]; + private _comparator: ComparatorFn; + private _maxSize: number; /** * Create a new PriorityQueue. * @param {function(any, any): boolean} comparator Comparator function to determine priority. Defaults to a MaxHeap. */ - constructor(comparator = (a, b) => a > b, maxSize = Infinity) { + constructor(comparator: ComparatorFn = (a, b) => a > b, maxSize: number = Infinity) { this._heap = []; this._comparator = comparator; this._maxSize = maxSize; @@ -33,7 +39,7 @@ export class PriorityQueue { /** * The size of the queue */ - get size() { + get size(): number { return this._heap.length; } @@ -41,7 +47,7 @@ export class PriorityQueue { * Check if the queue is empty. * @returns {boolean} `true` if the queue is empty, `false` otherwise. */ - isEmpty() { + isEmpty(): boolean { return this.size === 0; } @@ -49,7 +55,7 @@ export class PriorityQueue { * Return the element with the highest priority in the queue. * @returns {any} The highest priority element in the queue. */ - peek() { + peek(): T | undefined { return this._heap[0]; } @@ -58,7 +64,7 @@ export class PriorityQueue { * @param {...any} values The values to push into the queue. * @returns {number} The new size of the queue. */ - push(...values) { + push(...values: T[]): number { return this.extend(values); } @@ -67,7 +73,7 @@ export class PriorityQueue { * @param {any[]} values The values to push into the queue. * @returns {number} The new size of the queue. */ - extend(values) { + extend(values: T[]): number { for (const value of values) { if (this.size < this._maxSize) { this._heap.push(value); @@ -91,7 +97,7 @@ export class PriorityQueue { * Remove and return the element with the highest priority in the queue. * @returns {any} The element with the highest priority in the queue. */ - pop() { + pop(): T | undefined { const poppedValue = this.peek(); const bottom = this.size - 1; if (bottom > 0) { @@ -107,7 +113,7 @@ export class PriorityQueue { * @param {*} value The new value. * @returns {*} The replaced value. */ - replace(value) { + replace(value: T): T | undefined { const replacedValue = this.peek(); this._heap[0] = value; this._siftDown(); @@ -120,7 +126,7 @@ export class PriorityQueue { * @returns {number} The index of the parent node. * @private */ - _parent(i) { + private _parent(i: number): number { return ((i + 1) >>> 1) - 1; } @@ -130,7 +136,7 @@ export class PriorityQueue { * @returns {number} The index of the left child. * @private */ - _left(i) { + private _left(i: number): number { return (i << 1) + 1; } @@ -140,7 +146,7 @@ export class PriorityQueue { * @returns {number} The index of the right child. * @private */ - _right(i) { + private _right(i: number): number { return (i + 1) << 1; } @@ -151,7 +157,7 @@ export class PriorityQueue { * @returns {boolean} `true` if the element at index `i` is greater than the element at index `j`, `false` otherwise. * @private */ - _greater(i, j) { + private _greater(i: number, j: number): boolean { return this._comparator(this._heap[i], this._heap[j]); } @@ -161,7 +167,7 @@ export class PriorityQueue { * @param {number} j The index of the second element to swap. * @private */ - _swap(i, j) { + private _swap(i: number, j: number): void { const temp = this._heap[i]; this._heap[i] = this._heap[j]; this._heap[j] = temp; @@ -172,7 +178,7 @@ export class PriorityQueue { * starting at the last element and moving up the heap. * @private */ - _siftUp() { + private _siftUp(): void { this._siftUpFrom(this.size - 1); } @@ -180,7 +186,7 @@ export class PriorityQueue { * Helper function to sift up from a given node. * @param {number} node The index of the node to start sifting up from. */ - _siftUpFrom(node) { + private _siftUpFrom(node: number): void { while (node > 0 && this._greater(node, this._parent(node))) { this._swap(node, this._parent(node)); node = this._parent(node); @@ -192,7 +198,7 @@ export class PriorityQueue { * starting at the first element and moving down the heap. * @private */ - _siftDown() { + private _siftDown(): void { let node = 0; while ( (this._left(node) < this.size && this._greater(this._left(node), node)) || @@ -211,7 +217,7 @@ export class PriorityQueue { * the index can be computed without needing to traverse the heap. * @private */ - _smallest() { + private _smallest(): number { return (2 ** (Math.floor(Math.log2(this.size))) - 1); } } @@ -220,6 +226,8 @@ export class PriorityQueue { * A trie structure to efficiently store and search for strings. */ export class CharTrie { + private root: CharTrieNode; + constructor() { this.root = CharTrieNode.default(); } @@ -228,7 +236,7 @@ export class CharTrie { * Adds one or more `texts` to the trie. * @param {string[]} texts The strings to add to the trie. */ - extend(texts) { + extend(texts: string[]): void { for (const text of texts) { this.push(text); } @@ -238,7 +246,7 @@ export class CharTrie { * Adds text to the trie. * @param {string} text The string to add to the trie. */ - push(text) { + push(text: string): void { let node = this.root; for (const ch of text) { let child = node.children.get(ch); @@ -256,7 +264,7 @@ export class CharTrie { * @param {string} text The common prefix to search for. * @yields {string} Each string in the trie that has `text` as a prefix. */ - *commonPrefixSearch(text) { + *commonPrefixSearch(text: string): Generator { let node = this.root; if (node === undefined) return; @@ -276,12 +284,15 @@ export class CharTrie { * Represents a node in a character trie. */ class CharTrieNode { + public isLeaf: boolean; + public children: Map; + /** * Create a new CharTrieNode. * @param {boolean} isLeaf Whether the node is a leaf node or not. * @param {Map} children A map containing the node's children, where the key is a character and the value is a `CharTrieNode`. */ - constructor(isLeaf, children) { + constructor(isLeaf: boolean, children: Map) { this.isLeaf = isLeaf; this.children = children; } @@ -290,7 +301,7 @@ class CharTrieNode { * Returns a new `CharTrieNode` instance with default values. * @returns {CharTrieNode} A new `CharTrieNode` instance with `isLeaf` set to `false` and an empty `children` map. */ - static default() { + static default(): CharTrieNode { return new CharTrieNode(false, new Map()); } } @@ -299,6 +310,14 @@ class CharTrieNode { * A lattice data structure to be used for tokenization. */ export class TokenLattice { + private chars: string[]; + private len: number; + private bosTokenId: number; + private eosTokenId: number; + private nodes: TokenLatticeNode[]; + private beginNodes: TokenLatticeNode[][]; + private endNodes: TokenLatticeNode[][]; + /** * Creates a new TokenLattice instance. * @@ -306,7 +325,7 @@ export class TokenLattice { * @param {number} bosTokenId The beginning-of-sequence token ID. * @param {number} eosTokenId The end-of-sequence token ID. */ - constructor(sentence, bosTokenId, eosTokenId) { + constructor(sentence: string, bosTokenId: number, eosTokenId: number) { this.chars = Array.from(sentence); this.len = this.chars.length; this.bosTokenId = bosTokenId; @@ -331,7 +350,7 @@ export class TokenLattice { * @param {number} score The score of the token. * @param {number} tokenId The token ID of the token. */ - insert(pos, length, score, tokenId) { + insert(pos: number, length: number, score: number, tokenId: number): void { const nodeId = this.nodes.length; const node = new TokenLatticeNode(tokenId, nodeId, pos, length, score); this.beginNodes[pos].push(node); @@ -344,7 +363,7 @@ export class TokenLattice { * * @returns {TokenLatticeNode[]} The most likely sequence of tokens. */ - viterbi() { + viterbi(): TokenLatticeNode[] { const len = this.len; let pos = 0; while (pos <= len) { @@ -373,7 +392,7 @@ export class TokenLattice { ++pos; } - const results = []; + const results: TokenLatticeNode[] = []; const root = this.beginNodes[len][0]; const prev = root.prev; if (prev === null) { @@ -395,14 +414,14 @@ export class TokenLattice { * @param {TokenLatticeNode} node * @returns {string} The array of nodes representing the most likely sequence of tokens. */ - piece(node) { + piece(node: TokenLatticeNode): string { return this.chars.slice(node.pos, node.pos + node.length).join(''); } /** * @returns {string[]} The most likely sequence of tokens. */ - tokens() { + tokens(): string[] { const nodes = this.viterbi(); return nodes.map(x => this.piece(x)); } @@ -410,12 +429,24 @@ export class TokenLattice { /** * @returns {number[]} The most likely sequence of token ids. */ - tokenIds() { + tokenIds(): number[] { const nodes = this.viterbi(); return nodes.map(x => x.tokenId); } } + +/** + * Represents a node in a token lattice. + */ class TokenLatticeNode { + public tokenId: number; + public nodeId: number; + public pos: number; + public length: number; + public score: number; + public prev: TokenLatticeNode | null; + public backtraceScore: number; + /** * Represents a node in a token lattice for a given sentence. * @param {number} tokenId The ID of the token associated with this node. @@ -424,7 +455,7 @@ class TokenLatticeNode { * @param {number} length The length of the token. * @param {number} score The score associated with the token. */ - constructor(tokenId, nodeId, pos, length, score) { + constructor(tokenId: number, nodeId: number, pos: number, length: number, score: number) { this.tokenId = tokenId; this.nodeId = nodeId; this.pos = pos; @@ -438,7 +469,7 @@ class TokenLatticeNode { * Returns a clone of this node. * @returns {TokenLatticeNode} A clone of this node. */ - clone() { + clone(): TokenLatticeNode { const n = new TokenLatticeNode(this.tokenId, this.nodeId, this.pos, this.length, this.score); n.prev = this.prev; n.backtraceScore = this.backtraceScore; diff --git a/src/utils/devices.js b/src/utils/devices.ts similarity index 90% rename from src/utils/devices.js rename to src/utils/devices.ts index 1086b33e4..0340bbf12 100644 --- a/src/utils/devices.js +++ b/src/utils/devices.ts @@ -15,8 +15,9 @@ export const DEVICE_TYPES = Object.freeze({ 'webnn-npu': 'webnn-npu', // WebNN NPU 'webnn-gpu': 'webnn-gpu', // WebNN GPU 'webnn-cpu': 'webnn-cpu', // WebNN CPU -}); +} as const); /** * @typedef {keyof typeof DEVICE_TYPES} DeviceType */ +export type DeviceType = keyof typeof DEVICE_TYPES; \ No newline at end of file diff --git a/src/utils/dtypes.js b/src/utils/dtypes.ts similarity index 97% rename from src/utils/dtypes.js rename to src/utils/dtypes.ts index 845eef5e0..942b4a67f 100644 --- a/src/utils/dtypes.js +++ b/src/utils/dtypes.ts @@ -44,6 +44,7 @@ export const DATA_TYPES = Object.freeze({ q4f16: 'q4f16', // fp16 model with int4 block weight quantization }); /** @typedef {keyof typeof DATA_TYPES} DataType */ +export type DataType = keyof typeof DATA_TYPES; export const DEFAULT_DEVICE_DTYPE_MAPPING = Object.freeze({ // NOTE: If not specified, will default to fp32 diff --git a/src/utils/generic.js b/src/utils/generic.ts similarity index 56% rename from src/utils/generic.js rename to src/utils/generic.ts index 5ccd467ad..3c09a5ebe 100644 --- a/src/utils/generic.js +++ b/src/utils/generic.ts @@ -1,3 +1,15 @@ +/** + * Type definition for a callable function that can be instantiated + */ +export type CallableFunction = { + (...args: any[]): any; + _call(...args: any[]): any; +} + +/** + * Type definition for the callable constructor + */ +export type CallableConstructor = new () => CallableFunction; /** * A base class for creating callable objects. @@ -5,7 +17,7 @@ * * @type {new () => {(...args: any[]): any, _call(...args: any[]): any}} */ -export const Callable = /** @type {any} */ (class { +export const Callable: CallableConstructor = class { /** * Creates a new instance of the Callable class. */ @@ -16,10 +28,14 @@ export const Callable = /** @type {any} */ (class { * @param {...any} args Zero or more arguments to pass to the '_call' method. * @returns {*} The result of calling the '_call' method. */ - let closure = function (...args) { - return closure._call(...args) - } - return Object.setPrototypeOf(closure, new.target.prototype) + const closure: CallableFunction = Object.assign( + function (this: CallableFunction, ...args: any[]): any { + return this._call(...args); + }, + { _call: this._call } + ); + + return Object.setPrototypeOf(closure, new.target.prototype); } /** @@ -29,7 +45,7 @@ export const Callable = /** @type {any} */ (class { * @param {any[]} args * @throws {Error} If the subclass does not implement the `_call` method. */ - _call(...args) { - throw Error('Must implement _call method in subclass') + _call(...args: any[]): any { + throw Error('Must implement _call method in subclass'); } -}); +} as unknown as CallableConstructor; diff --git a/src/utils/hub.js b/src/utils/hub.ts similarity index 86% rename from src/utils/hub.js rename to src/utils/hub.ts index 17ee4c1b1..3d81e6db5 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.ts @@ -1,4 +1,3 @@ - /** * @file Utility functions to interact with the Hugging Face Hub (https://huggingface.co/models) * @@ -8,13 +7,22 @@ import fs from 'fs'; import path from 'path'; -import { env } from '../env.js'; -import { dispatchCallback } from './core.js'; +import { env } from '../env'; +import { dispatchCallback } from './core'; +import { type DeviceType } from './devices'; +import { type DataType } from './dtypes'; +import { type ProgressCallback } from './core'; +import { type PretrainedConfig } from '../configs'; +import { type InferenceSession } from 'onnxruntime-common'; +export interface ICache { + match(request: string): Promise; + put(request: string, response: Response | FileResponse): Promise; +} /** * @typedef {Object} PretrainedOptions Options for loading a pretrained model. - * @property {import('./core.js').ProgressCallback} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates. - * @property {import('../configs.js').PretrainedConfig} [config=null] Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when: + * @property {import('./core').ProgressCallback} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates. + * @property {import('../configs').PretrainedConfig} [config=null] Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when: * - The model is a model provided by the library (loaded with the *model id* string of a pretrained model). * - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory. * @property {string} [cache_dir=null] Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. @@ -23,6 +31,13 @@ import { dispatchCallback } from './core.js'; * since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. * NOTE: This setting is ignored for local requests. */ +export interface PretrainedOptions { + progress_callback?: ProgressCallback; + config?: PretrainedConfig; + cache_dir?: string; + local_files_only?: boolean; + revision?: string; +} /** * @typedef {Object} ModelSpecificPretrainedOptions Options for loading a pretrained model. @@ -34,15 +49,24 @@ import { dispatchCallback } from './core.js'; * @property {boolean|Record} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size). * @property {import('onnxruntime-common').InferenceSession.SessionOptions} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen. */ +export interface ModelSpecificPretrainedOptions { + subfolder: string; + model_file_name?: string | null; + device?: DeviceType | Record | null; + dtype?: DataType | Record | null; + use_external_data_format?: boolean | Record | false; + session_options?: InferenceSession.SessionOptions; +} /** * @typedef {PretrainedOptions & ModelSpecificPretrainedOptions} PretrainedModelOptions Options for loading a pretrained model. */ +export type PretrainedModelOptions = PretrainedOptions & ModelSpecificPretrainedOptions; /** * Mapping from file extensions to MIME types. */ -const CONTENT_TYPE_MAP = { +const CONTENT_TYPE_MAP: Record = { 'txt': 'text/plain', 'html': 'text/html', 'css': 'text/css', @@ -52,23 +76,30 @@ const CONTENT_TYPE_MAP = { 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg', 'gif': 'image/gif', -} +} as const; + class FileResponse { + private filePath: string; + public headers: Headers; + public exists: boolean; + public status: number; + public statusText: string; + public body: ReadableStream | null; /** * Creates a new `FileResponse` object. * @param {string|URL} filePath */ - constructor(filePath) { - this.filePath = filePath; + constructor(filePath: string | URL) { + this.filePath = filePath.toString(); this.headers = new Headers(); - this.exists = fs.existsSync(filePath); + this.exists = fs.existsSync(this.filePath); if (this.exists) { this.status = 200; this.statusText = 'OK'; - let stats = fs.statSync(filePath); + let stats = fs.statSync(this.filePath); this.headers.set('content-length', stats.size.toString()); this.updateContentType(); @@ -94,9 +125,9 @@ class FileResponse { * the file specified by the filePath property of the current object. * @returns {void} */ - updateContentType() { + private updateContentType(): void { // Set content-type header based on file extension - const extension = this.filePath.toString().split('.').pop().toLowerCase(); + const extension = this.filePath.toString().split('.').pop()?.toLowerCase() ?? ''; this.headers.set('content-type', CONTENT_TYPE_MAP[extension] ?? 'application/octet-stream'); } @@ -104,7 +135,7 @@ class FileResponse { * Clone the current FileResponse object. * @returns {FileResponse} A new FileResponse object with the same properties as the current object. */ - clone() { + clone(): FileResponse { let response = new FileResponse(this.filePath); response.exists = this.exists; response.status = this.status; @@ -119,9 +150,9 @@ class FileResponse { * @returns {Promise} A Promise that resolves with an ArrayBuffer containing the file's contents. * @throws {Error} If the file cannot be read. */ - async arrayBuffer() { + async arrayBuffer(): Promise { const data = await fs.promises.readFile(this.filePath); - return /** @type {ArrayBuffer} */ (data.buffer); + return data.buffer; } /** @@ -130,7 +161,7 @@ class FileResponse { * @returns {Promise} A Promise that resolves with a Blob containing the file's contents. * @throws {Error} If the file cannot be read. */ - async blob() { + async blob(): Promise { const data = await fs.promises.readFile(this.filePath); return new Blob([data], { type: this.headers.get('content-type') }); } @@ -141,7 +172,7 @@ class FileResponse { * @returns {Promise} A Promise that resolves with a string containing the file's contents. * @throws {Error} If the file cannot be read. */ - async text() { + async text(): Promise { const data = await fs.promises.readFile(this.filePath, 'utf8'); return data; } @@ -153,7 +184,7 @@ class FileResponse { * @returns {Promise} A Promise that resolves with a parsed JavaScript object containing the file's contents. * @throws {Error} If the file cannot be read. */ - async json() { + async json(): Promise { return JSON.parse(await this.text()); } } @@ -165,8 +196,8 @@ class FileResponse { * @param {string[]} [validHosts=null] A list of valid hostnames. If specified, the URL's hostname must be in this list. * @returns {boolean} True if the string is a valid URL, false otherwise. */ -function isValidUrl(string, protocols = null, validHosts = null) { - let url; +function isValidUrl(string: string | URL, protocols: string[] | null = null, validHosts: string[] | null = null): boolean { + let url: URL; try { url = new URL(string); } catch (_) { @@ -187,8 +218,7 @@ function isValidUrl(string, protocols = null, validHosts = null) { * @param {URL|string} urlOrPath The URL/path of the file to get. * @returns {Promise} A promise that resolves to a FileResponse object (if the file is retrieved using the FileSystem API), or a Response object (if the file is retrieved using the Fetch API). */ -export async function getFile(urlOrPath) { - +export async function getFile(urlOrPath: URL | string): Promise { if (env.useFS && !isValidUrl(urlOrPath, ['http:', 'https:', 'blob:'])) { return new FileResponse(urlOrPath); @@ -210,16 +240,16 @@ export async function getFile(urlOrPath) { headers.set('Authorization', `Bearer ${token}`); } } - return fetch(urlOrPath, { headers }); + return fetch(urlOrPath.toString(), { headers }); } else { // Running in a browser-environment, so we use default headers // NOTE: We do not allow passing authorization headers in the browser, // since this would require exposing the token to the client. - return fetch(urlOrPath); + return fetch(urlOrPath.toString()); } } -const ERROR_MAPPING = { +const ERROR_MAPPING: Record = { // 4xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses) 400: 'Bad request error occurred while trying to load file', 401: 'Unauthorized access to file', @@ -232,7 +262,8 @@ const ERROR_MAPPING = { 502: 'Bad gateway error occurred while trying to load file', 503: 'Service unavailable error occurred while trying to load file', 504: 'Gateway timeout error occurred while trying to load file', -} +} as const; + /** * Helper method to handle fatal errors that occur while trying to load a file from the Hugging Face Hub. * @param {number} status The HTTP status code of the error. @@ -241,7 +272,7 @@ const ERROR_MAPPING = { * @returns {null} Returns `null` if `fatal = true`. * @throws {Error} If `fatal = false`. */ -function handleError(status, remoteURL, fatal) { +function handleError(status: number, remoteURL: string, fatal: boolean): null { if (!fatal) { // File was not loaded correctly, but it is optional. // TODO in future, cache the response? @@ -253,11 +284,13 @@ function handleError(status, remoteURL, fatal) { } class FileCache { + private path: string; + /** * Instantiate a `FileCache` object. * @param {string} path */ - constructor(path) { + constructor(path: string) { this.path = path; } @@ -266,8 +299,7 @@ class FileCache { * @param {string} request * @returns {Promise} */ - async match(request) { - + async match(request: string): Promise { let filePath = path.join(this.path, request); let file = new FileResponse(filePath); @@ -284,7 +316,7 @@ class FileCache { * @param {Response|FileResponse} response * @returns {Promise} */ - async put(request, response) { + async put(request: string, response: Response | FileResponse): Promise { const buffer = Buffer.from(await response.arrayBuffer()); let outputPath = path.join(this.path, request); @@ -307,12 +339,12 @@ class FileCache { } /** - * + * Helper function to try to get a file from cache. * @param {FileCache|Cache} cache The cache to search * @param {string[]} names The names of the item to search for * @returns {Promise} The item from the cache, or undefined if not found. */ -async function tryCache(cache, ...names) { +async function tryCache(cache: ICache, ...names: string[]): Promise { for (let name of names) { try { let result = await cache.match(name); @@ -337,9 +369,14 @@ async function tryCache(cache, ...names) { * @param {PretrainedOptions} [options] An object containing optional parameters. * * @throws Will throw an error if the file is not found and `fatal` is true. - * @returns {Promise} A Promise that resolves with the file content as a buffer. + * @returns {Promise} A Promise that resolves with the file content as a buffer. */ -export async function getModelFile(path_or_repo_id, filename, fatal = true, options = {}) { +export async function getModelFile( + path_or_repo_id: string, + filename: string, + fatal: boolean = true, + options: PretrainedOptions = {} +): Promise { if (!env.allowLocalModels) { // User has disabled local models, so we just make sure other settings are correct. @@ -360,7 +397,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti // First, check if the a caching backend is available // If no caching mechanism available, will download the file every time - let cache; + let cache: ICache | undefined; if (!cache && env.useBrowserCache) { if (typeof caches === 'undefined') { throw Error('Browser cache is not available in this environment.') @@ -418,15 +455,13 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti // If a specific revision is requested, we account for this in the cache key. let fsCacheKey = revision === 'main' ? requestURL : pathJoin(path_or_repo_id, revision, filename); - /** @type {string} */ - let cacheKey; + let cacheKey: string; let proposedCacheKey = cache instanceof FileCache ? fsCacheKey : remoteURL; // Whether to cache the final response in the end. let toCacheResponse = false; - /** @type {Response|FileResponse|undefined} */ - let response; + let response: Response | FileResponse | undefined; if (cache) { // A caching system is available, so we try to get the file from it. @@ -504,8 +539,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti file: filename }) - /** @type {Uint8Array} */ - let buffer; + let buffer: Uint8Array; if (!options.progress_callback) { // If no progress callback is specified, we can use the `.arrayBuffer()` @@ -580,7 +614,12 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti * @returns {Promise} The JSON data parsed into a JavaScript object. * @throws Will throw an error if the file is not found and `fatal` is true. */ -export async function getModelJSON(modelPath, fileName, fatal = true, options = {}) { +export async function getModelJSON( + modelPath: string, + fileName: string, + fatal: boolean = true, + options: PretrainedOptions = {} +): Promise { let buffer = await getModelFile(modelPath, fileName, fatal, options); if (buffer === null) { // Return empty object @@ -592,6 +631,7 @@ export async function getModelJSON(modelPath, fileName, fatal = true, options = return JSON.parse(jsonData); } + /** * Read and track progress when reading a Response object * @@ -599,7 +639,10 @@ export async function getModelJSON(modelPath, fileName, fatal = true, options = * @param {(data: {progress: number, loaded: number, total: number}) => void} progress_callback The function to call with progress updates * @returns {Promise} A Promise that resolves with the Uint8Array buffer */ -async function readResponse(response, progress_callback) { +async function readResponse( + response: Response | FileResponse, + progress_callback: (data: { progress: number, loaded: number, total: number }) => void +): Promise { const contentLength = response.headers.get('Content-Length'); if (contentLength === null) { @@ -610,7 +653,7 @@ async function readResponse(response, progress_callback) { let loaded = 0; const reader = response.body.getReader(); - async function read() { + async function read(): Promise { const { done, value } = await reader.read(); if (done) return; @@ -654,7 +697,7 @@ async function readResponse(response, progress_callback) { * @param {...string} parts Multiple parts of a path. * @returns {string} A string representing the joined path. */ -function pathJoin(...parts) { +function pathJoin(...parts: string[]): string { // https://stackoverflow.com/a/55142565 parts = parts.map((part, index) => { if (index) { diff --git a/src/utils/image.js b/src/utils/image.ts similarity index 94% rename from src/utils/image.js rename to src/utils/image.ts index 40f51625e..04d5d55e5 100644 --- a/src/utils/image.js +++ b/src/utils/image.ts @@ -1,4 +1,3 @@ - /** * @file Helper module for image processing. * @@ -8,21 +7,23 @@ * @module utils/image */ -import { isNullishDimension, saveBlob } from './core.js'; -import { getFile } from './hub.js'; -import { apis } from '../env.js'; -import { Tensor } from './tensor.js'; +import { isNullishDimension, saveBlob } from './core'; +import { getFile } from './hub'; +import { apis } from '../env'; +import { Tensor } from './tensor'; // Will be empty (or not used) if running in browser or web-worker import sharp from 'sharp'; +export type ImageChannels = 1 | 2 | 3 | 4; + let createCanvasFunction; let ImageDataClass; let loadImageFunction; const IS_BROWSER_OR_WEBWORKER = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV; if (IS_BROWSER_OR_WEBWORKER) { // Running in browser or web-worker - createCanvasFunction = (/** @type {number} */ width, /** @type {number} */ height) => { + createCanvasFunction = (width: number, height: number) => { if (!self.OffscreenCanvas) { throw new Error('OffscreenCanvas not supported by this browser.'); } @@ -34,7 +35,7 @@ if (IS_BROWSER_OR_WEBWORKER) { } else if (sharp) { // Running in Node.js, electron, or other non-browser environment - loadImageFunction = async (/**@type {sharp.Sharp}*/img) => { + loadImageFunction = async (img: sharp.Sharp) => { const metadata = await img.metadata(); const rawChannels = metadata.channels; @@ -75,6 +76,10 @@ const CONTENT_TYPE_MAP = new Map([ ]); export class RawImage { + data: Uint8ClampedArray | Uint8Array; + width: number; + height: number; + channels: ImageChannels; /** * Create a new `RawImage` object. @@ -83,7 +88,7 @@ export class RawImage { * @param {number} height The height of the image. * @param {1|2|3|4} channels The number of channels. */ - constructor(data, width, height, channels) { + constructor(data: Uint8ClampedArray | Uint8Array, width: number, height: number, channels: ImageChannels) { this.data = data; this.width = width; this.height = height; @@ -94,7 +99,7 @@ export class RawImage { * Returns the size of the image (width, height). * @returns {[number, number]} The size of the image (width, height). */ - get size() { + get size(): [number, number] { return [this.width, this.height]; } @@ -114,7 +119,7 @@ export class RawImage { * // } * ``` */ - static async read(input) { + static async read(input: RawImage | string | URL) { if (input instanceof RawImage) { return input; } else if (typeof input === 'string' || input instanceof URL) { @@ -129,12 +134,12 @@ export class RawImage { * @param {HTMLCanvasElement|OffscreenCanvas} canvas The canvas to read the image from. * @returns {RawImage} The image object. */ - static fromCanvas(canvas) { + static fromCanvas(canvas: HTMLCanvasElement | OffscreenCanvas): RawImage { if (!IS_BROWSER_OR_WEBWORKER) { throw new Error('fromCanvas() is only supported in browser environments.') } - const ctx = canvas.getContext('2d'); + const ctx = canvas.getContext('2d') as CanvasRenderingContext2D | OffscreenCanvasRenderingContext2D; const data = ctx.getImageData(0, 0, canvas.width, canvas.height).data; return new RawImage(data, canvas.width, canvas.height, 4); } @@ -144,7 +149,7 @@ export class RawImage { * @param {string|URL} url The URL or file path to read the image from. * @returns {Promise} The image object. */ - static async fromURL(url) { + static async fromURL(url: string | URL): Promise { const response = await getFile(url); if (response.status !== 200) { throw new Error(`Unable to read image from "${url}" (${response.status} ${response.statusText})`); @@ -158,7 +163,7 @@ export class RawImage { * @param {Blob} blob The blob to read the image from. * @returns {Promise} The image object. */ - static async fromBlob(blob) { + static async fromBlob(blob: Blob): Promise { if (IS_BROWSER_OR_WEBWORKER) { // Running in environment with canvas const img = await loadImageFunction(blob); @@ -182,7 +187,7 @@ export class RawImage { * Helper method to create a new Image from a tensor * @param {Tensor} tensor */ - static fromTensor(tensor, channel_format = 'CHW') { + static fromTensor(tensor: Tensor, channel_format = 'CHW') { if (tensor.dims.length !== 3) { throw new Error(`Tensor should have 3 dimensions, but has ${tensor.dims.length} dimensions.`); } @@ -212,7 +217,7 @@ export class RawImage { * Convert the image to grayscale format. * @returns {RawImage} `this` to support chaining. */ - grayscale() { + grayscale(): RawImage { if (this.channels === 1) { return this; } @@ -239,7 +244,7 @@ export class RawImage { * Convert the image to RGB format. * @returns {RawImage} `this` to support chaining. */ - rgb() { + rgb(): RawImage { if (this.channels === 3) { return this; } @@ -272,7 +277,7 @@ export class RawImage { * Convert the image to RGBA format. * @returns {RawImage} `this` to support chaining. */ - rgba() { + rgba(): RawImage { if (this.channels === 4) { return this; } @@ -311,7 +316,7 @@ export class RawImage { * @throws {Error} If the image does not have 4 channels. * @throws {Error} If the mask is not a single channel. */ - putAlpha(mask) { + putAlpha(mask: RawImage): RawImage { if (mask.width !== this.width || mask.height !== this.height) { throw new Error(`Expected mask size to be ${this.width}x${this.height}, but got ${mask.width}x${mask.height}`); } @@ -351,9 +356,9 @@ export class RawImage { * @param {0|1|2|3|4|5|string} [options.resample] The resampling method to use. * @returns {Promise} `this` to support chaining. */ - async resize(width, height, { + async resize(width: number, height: number, { resample = 2, - } = {}) { + }: { resample?: 0 | 1 | 2 | 3 | 4 | 5 | string; } = {}): Promise { // Do nothing if the image already has the desired size if (this.width === width && this.height === height) { @@ -703,11 +708,10 @@ export class RawImage { * Inspired by PIL's `Image.split()` [function](https://pillow.readthedocs.io/en/latest/reference/Image.html#PIL.Image.Image.split). * @returns {RawImage[]} An array containing bands. */ - split() { + split(): RawImage[] { const { data, width, height, channels } = this; - /** @type {typeof Uint8Array | typeof Uint8ClampedArray} */ - const data_type = /** @type {any} */(data.constructor); + const data_type: typeof Uint8Array | typeof Uint8ClampedArray = data.constructor as typeof Uint8Array | typeof Uint8ClampedArray; const per_channel_length = data.length / channels; // Pre-allocate buffers for each channel @@ -734,7 +738,7 @@ export class RawImage { * @param {1|2|3|4|null} [channels] The new number of channels of the image. * @private */ - _update(data, width, height, channels = null) { + _update(data: Uint8ClampedArray, width: number, height: number, channels: ImageChannels | null = null) { this.data = data; this.width = width; this.height = height; @@ -748,7 +752,7 @@ export class RawImage { * Clone the image * @returns {RawImage} The cloned image */ - clone() { + clone(): RawImage { return new RawImage(this.data.slice(), this.width, this.height, this.channels); } @@ -757,7 +761,7 @@ export class RawImage { * @param {number} numChannels The number of channels. Must be 1, 3, or 4. * @returns {RawImage} `this` to support chaining. */ - convert(numChannels) { + convert(numChannels: number): RawImage { if (this.channels === numChannels) return this; // Already correct number of channels switch (numChannels) { @@ -780,7 +784,7 @@ export class RawImage { * Save the image to the given path. * @param {string} path The path to save the image to. */ - async save(path) { + async save(path: string) { if (IS_BROWSER_OR_WEBWORKER) { if (apis.IS_WEBWORKER_ENV) { diff --git a/src/utils/maths.js b/src/utils/maths.ts similarity index 88% rename from src/utils/maths.js rename to src/utils/maths.ts index 107068dca..da481665d 100644 --- a/src/utils/maths.js +++ b/src/utils/maths.ts @@ -1,4 +1,3 @@ - /** * @file Helper module for mathematical processing. * @@ -14,10 +13,26 @@ * @typedef {TypedArray | BigTypedArray} AnyTypedArray */ +export type TypedArray = Int8Array | Uint8Array | Uint8ClampedArray | Int16Array | Uint16Array | Int32Array | Uint32Array | Float32Array | Float64Array; +export type BigTypedArray = BigInt64Array | BigUint64Array; +export type AnyTypedArray = TypedArray | BigTypedArray; + + +// convert to original type +export function toOriginalArrType(originalArr: T, newArr: number[]): T { + return newArr as T; + if (Array.isArray(originalArr)) { + return newArr as T; + } else { + // @ts-ignore + return new (originalArr.constructor)(newArr) as T; + } +} + /** * @param {TypedArray} input */ -export function interpolate_data(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) { +export function interpolate_data(input: TypedArray, [in_channels, in_height, in_width]: [number, number, number], [out_height, out_width]: [number, number], mode = 'bilinear', align_corners = false): TypedArray { // TODO use mode and align_corners // Output image dimensions @@ -95,7 +110,7 @@ export function interpolate_data(input, [in_channels, in_height, in_width], [out * @param {number[]} axes * @returns {[T, number[]]} The permuted array and the new shape. */ -export function permute_data(array, dims, axes) { +export function permute_data(array: T, dims: number[], axes: number[]): [T, number[]] { // Calculate the new shape of the permuted array // and the stride of the original array const shape = new Array(axes.length); @@ -134,21 +149,20 @@ export function permute_data(array, dims, axes) { * @param {T} arr The array of numbers to compute the softmax of. * @returns {T} The softmax array. */ -export function softmax(arr) { +export function softmax(arr: T): T { // Compute the maximum value in the array const maxVal = max(arr)[0]; // Compute the exponentials of the array values - const exps = arr.map(x => Math.exp(x - maxVal)); + const exps = Array.from(arr, x => Math.exp(x - maxVal)); // Compute the sum of the exponentials - // @ts-ignore const sumExps = exps.reduce((acc, val) => acc + val, 0); // Compute the softmax values const softmaxArr = exps.map(x => x / sumExps); - return /** @type {T} */(softmaxArr); + return toOriginalArrType(arr, softmaxArr); } /** @@ -157,7 +171,7 @@ export function softmax(arr) { * @param {T} arr The input array to calculate the log_softmax function for. * @returns {T} The resulting log_softmax array. */ -export function log_softmax(arr) { +export function log_softmax(arr: T): T { // Compute the maximum value in the array const maxVal = max(arr)[0]; @@ -171,9 +185,9 @@ export function log_softmax(arr) { const logSum = Math.log(sumExps); // Compute the softmax values - const logSoftmaxArr = arr.map(x => x - maxVal - logSum); + const logSoftmaxArr = Array.from(arr, x => x - maxVal - logSum); - return /** @type {T} */(logSoftmaxArr); + return toOriginalArrType(arr, logSoftmaxArr); } /** @@ -182,7 +196,7 @@ export function log_softmax(arr) { * @param {number[]} arr2 The second array. * @returns {number} The dot product of arr1 and arr2. */ -export function dot(arr1, arr2) { +export function dot(arr1: number[], arr2: number[]): number { let result = 0; for (let i = 0; i < arr1.length; ++i) { result += arr1[i] * arr2[i]; @@ -197,7 +211,7 @@ export function dot(arr1, arr2) { * @param {number[]} arr2 The second array. * @returns {number} The cosine similarity between the two arrays. */ -export function cos_sim(arr1, arr2) { +export function cos_sim(arr1: number[], arr2: number[]): number { // Calculate dot product of the two arrays const dotProduct = dot(arr1, arr2); @@ -218,7 +232,7 @@ export function cos_sim(arr1, arr2) { * @param {number[]} arr The array to calculate the magnitude of. * @returns {number} The magnitude of the array. */ -export function magnitude(arr) { +export function magnitude(arr: number[]): number { return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0)); } @@ -230,7 +244,7 @@ export function magnitude(arr) { * @returns {T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number]} the value and index of the minimum element, of the form: [valueOfMin, indexOfMin] * @throws {Error} If array is empty. */ -export function min(arr) { +export function min(arr: T): T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number] { if (arr.length === 0) throw Error('Array must not be empty'); let min = arr[0]; let indexOfMin = 0; @@ -240,7 +254,8 @@ export function min(arr) { indexOfMin = i; } } - return /** @type {T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number]} */([min, indexOfMin]); + // @ts-ignore - We know this is safe due to the type constraint and return type + return [min, indexOfMin]; } @@ -251,7 +266,7 @@ export function min(arr) { * @returns {T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax] * @throws {Error} If array is empty. */ -export function max(arr) { +export function max(arr: T): T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number] { if (arr.length === 0) throw Error('Array must not be empty'); let max = arr[0]; let indexOfMax = 0; @@ -261,10 +276,11 @@ export function max(arr) { indexOfMax = i; } } - return /** @type {T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number]} */([max, indexOfMax]); + // @ts-ignore - We know this is safe due to the type constraint and return type + return [max, indexOfMax]; } -function isPowerOfTwo(number) { +function isPowerOfTwo(number: number): boolean { // Check if the number is greater than 0 and has only one bit set to 1 return (number > 0) && ((number & (number - 1)) === 0); } @@ -277,11 +293,17 @@ function isPowerOfTwo(number) { * Code adapted from https://www.npmjs.com/package/fft.js */ class P2FFT { + private size: number; + private _csize: number; + private table: Float64Array; + private _width: number; + private _bitrev: Int32Array; + /** * @param {number} size The size of the input array. Must be a power of two larger than 1. * @throws {Error} FFT size must be a power of two larger than 1. */ - constructor(size) { + constructor(size: number) { this.size = size | 0; // convert to a 32-bit signed integer if (this.size <= 1 || !isPowerOfTwo(this.size)) throw new Error('FFT size must be a power of two larger than 1'); @@ -321,7 +343,7 @@ class P2FFT { * * @returns {Float64Array} A complex number array with size `2 * size` */ - createComplexArray() { + createComplexArray(): Float64Array { return new Float64Array(this._csize); } @@ -332,7 +354,7 @@ class P2FFT { * @param {number[]} [storage] An optional array to store the result in. * @returns {number[]} An array of real numbers representing the input complex number representation. */ - fromComplexArray(complex, storage) { + fromComplexArray(complex: Float64Array, storage: number[]): number[] { const res = storage || new Array(complex.length >>> 1); for (let i = 0; i < complex.length; i += 2) res[i >>> 1] = complex[i]; @@ -345,7 +367,7 @@ class P2FFT { * @param {Float64Array} [storage] Optional buffer to store the output array. * @returns {Float64Array} The complex-valued output array. */ - toComplexArray(input, storage) { + toComplexArray(input: Float64Array, storage: Float64Array): Float64Array { const res = storage || this.createComplexArray(); for (let i = 0; i < res.length; i += 2) { res[i] = input[i >>> 1]; @@ -364,7 +386,7 @@ class P2FFT { * * @returns {void} */ - transform(out, data) { + transform(out: Float64Array, data: Float64Array): void { if (out === data) throw new Error('Input and output buffers must be different'); @@ -381,7 +403,7 @@ class P2FFT { * * @throws {Error} If the input and output buffers are the same. */ - realTransform(out, data) { + realTransform(out: Float64Array, data: Float64Array) { if (out === data) throw new Error('Input and output buffers must be different'); @@ -398,7 +420,7 @@ class P2FFT { * @throws {Error} If `out` and `data` refer to the same buffer. * @returns {void} */ - inverseTransform(out, data) { + inverseTransform(out: Float64Array, data: Float64Array): void { if (out === data) throw new Error('Input and output buffers must be different'); @@ -415,7 +437,7 @@ class P2FFT { * @param {number} inv A scaling factor to apply to the transform. * @returns {void} */ - _transform4(out, data, inv) { + _transform4(out: Float64Array, data: Float64Array, inv: number): void { // radix-4 implementation const size = this._csize; @@ -425,8 +447,8 @@ class P2FFT { let step = 1 << width; let len = (size / step) << 1; - let outOff; - let t; + let outOff: number; + let t: number; const bitrev = this._bitrev; if (len === 4) { for (outOff = 0, t = 0; outOff < size; outOff += len, ++t) { @@ -516,7 +538,7 @@ class P2FFT { * @param {number} step The step size for indexing the input data. * @returns {void} */ - _singleTransform2(data, out, outOff, off, step) { + _singleTransform2(data: Float64Array, out: Float64Array, outOff: number, off: number, step: number): void { // radix-2 implementation // NOTE: Only called for len=4 @@ -543,7 +565,7 @@ class P2FFT { * * @returns {void} */ - _singleTransform4(data, out, outOff, off, step, inv) { + _singleTransform4(data: Float64Array, out: Float64Array, outOff: number, off: number, step: number, inv: number): void { // radix-4 // NOTE: Only called for len=8 const step2 = step * 2; @@ -586,7 +608,7 @@ class P2FFT { * @param {Float64Array} data Input array of real data to be transformed * @param {number} inv The scale factor used to normalize the inverse transform */ - _realTransform4(out, data, inv) { + _realTransform4(out: Float64Array, data: Float64Array, inv: number) { // Real input radix-4 implementation const size = this._csize; @@ -595,8 +617,8 @@ class P2FFT { let step = 1 << width; let len = (size / step) << 1; - let outOff; - let t; + let outOff: number; + let t: number; const bitrev = this._bitrev; if (len === 4) { for (outOff = 0, t = 0; outOff < size; outOff += len, ++t) { @@ -713,7 +735,7 @@ class P2FFT { * * @returns {void} */ - _singleRealTransform2(data, out, outOff, off, step) { + _singleRealTransform2(data: Float64Array, out: Float64Array, outOff: number, off: number, step: number): void { // radix-2 implementation // NOTE: Only called for len=4 @@ -737,7 +759,7 @@ class P2FFT { * @param {number} step The step size for the input array. * @param {number} inv The value of inverse. */ - _singleRealTransform4(data, out, outOff, off, step, inv) { + _singleRealTransform4(data: Float64Array, out: Float64Array, outOff: number, off: number, step: number, inv: number) { // radix-4 // NOTE: Only called for len=8 const step2 = step * 2; @@ -774,12 +796,21 @@ class P2FFT { * For more information, see: https://math.stackexchange.com/questions/77118/non-power-of-2-ffts/77156#77156 */ class NP2FFT { + public bufferSize: number; + private _a: number; + private _chirpBuffer: Float64Array; + private _buffer1: Float64Array; + private _buffer2: Float64Array; + private _outBuffer1: Float64Array; + private _outBuffer2: Float64Array; + private _slicedChirpBuffer: Float64Array; + private _f: P2FFT; /** * Constructs a new NP2FFT object. * @param {number} fft_length The length of the FFT */ - constructor(fft_length) { + constructor(fft_length: number) { // Helper variables const a = 2 * (fft_length - 1); const b = 2 * (2 * fft_length - 1); @@ -829,7 +860,7 @@ class NP2FFT { this._f.transform(this._chirpBuffer, ichirp); } - _transform(output, input, real) { + _transform(output: number[], input: number[], real: boolean) { const ib1 = this._buffer1; const ib2 = this._buffer2; const ob2 = this._outBuffer1; @@ -887,7 +918,12 @@ class NP2FFT { } export class FFT { - constructor(fft_length) { + private fft_length: number; + private isPowerOfTwo: boolean; + private fft: P2FFT | NP2FFT; + public outputBufferSize: number; + + constructor(fft_length: number) { this.fft_length = fft_length; this.isPowerOfTwo = isPowerOfTwo(fft_length); if (this.isPowerOfTwo) { @@ -914,7 +950,7 @@ export class FFT { * @param {AnyTypedArray} data The input array * @param {number} windowSize The window size */ -export function medianFilter(data, windowSize) { +export function medianFilter(data: AnyTypedArray, windowSize: number): AnyTypedArray { if (windowSize % 2 === 0 || windowSize <= 0) { throw new Error('Window size must be a positive odd number'); @@ -955,7 +991,7 @@ export function medianFilter(data, windowSize) { * @param {number} decimals The number of decimals * @returns {number} The rounded number */ -export function round(num, decimals) { +export function round(num: number, decimals: number): number { const pow = Math.pow(10, decimals); return Math.round(num * pow) / pow; } @@ -968,7 +1004,7 @@ export function round(num, decimals) { * @param {number} x The number to round * @returns {number} The rounded number */ -export function bankers_round(x) { +export function bankers_round(x: number): number { const r = Math.round(x); const br = Math.abs(x) % 1 === 0.5 ? (r % 2 === 0 ? r : r - 1) : r; return br; @@ -981,7 +1017,7 @@ export function bankers_round(x) { * @param {number[][]} matrix * @returns {number[][]} */ -export function dynamic_time_warping(matrix) { +export function dynamic_time_warping(matrix: number[][]): number[][] { const output_length = matrix.length; const input_length = matrix[0].length; @@ -1004,7 +1040,7 @@ export function dynamic_time_warping(matrix) { const c1 = cost[i - 1][j]; const c2 = cost[i][j - 1]; - let c, t; + let c: number, t: number; if (c0 < c1 && c0 < c2) { c = c0; t = 0; diff --git a/src/utils/tensor.js b/src/utils/tensor.ts similarity index 74% rename from src/utils/tensor.js rename to src/utils/tensor.ts index f13d5b8af..b5fbace64 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.ts @@ -11,14 +11,16 @@ import { interpolate_data, max, min, - permute_data -} from './maths.js'; + permute_data, + AnyTypedArray, + TypedArray +} from './maths'; import { Tensor as ONNXTensor, isONNXTensor, -} from '../backends/onnx.js'; +} from '../backends/onnx'; -import { TensorOpRegistry } from '../ops/registry.js'; +import { TensorOpRegistry } from '../ops/registry'; const DataTypeMap = Object.freeze({ float32: Float32Array, @@ -36,68 +38,71 @@ const DataTypeMap = Object.freeze({ bool: Uint8Array, uint4: Uint8Array, int4: Int8Array, -}); +} as const); /** * @typedef {keyof typeof DataTypeMap} DataType - * @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray + * @typedef {import('./maths').AnyTypedArray | any[]} DataArray */ - +export type DataType = keyof typeof DataTypeMap; +export type DataArray = AnyTypedArray | any[]; export class Tensor { + public ort_tensor: ONNXTensor; + /** @type {number[]} Dimensions of the tensor. */ - get dims() { + get dims(): number[] { // @ts-ignore return this.ort_tensor.dims; } - set dims(value) { + set dims(value: number[]) { // FIXME: ONNXTensor declares dims as readonly so one needs to use the constructor() if dims change. // @ts-ignore this.ort_tensor.dims = value; } /** @type {DataType} Type of the tensor. */ - get type() { + get type(): DataType { return this.ort_tensor.type; - }; + } /** @type {DataArray} The data stored in the tensor. */ - get data() { + get data(): DataArray { return this.ort_tensor.data; } - /** @type {number} The number of elements in the tensor. */ - get size() { + /** @type {number} The number of elements in the tensor. */ + get size(): number { return this.ort_tensor.size; - }; + } /** @type {string} The location of the tensor data. */ - get location() { + get location(): string { return this.ort_tensor.location; - }; + } - ort_tensor; /** * Create a new Tensor or copy an existing Tensor. * @param {[DataType, DataArray, number[]]|[ONNXTensor]} args */ - constructor(...args) { + constructor(...args: [DataType, DataArray, number[]] | [ONNXTensor]) { if (isONNXTensor(args[0])) { - this.ort_tensor = /** @type {ONNXTensor} */ (args[0]); + // @ts-ignore + this.ort_tensor = args[0]; } else { // Create new tensor this.ort_tensor = new ONNXTensor( - /** @type {DataType} */(args[0]), - /** @type {Exclude} */(args[1]), + args[0] as DataType, + args[1] as Exclude, args[2] ); } return new Proxy(this, { - get: (obj, key) => { + get: (obj: Tensor, key: string | symbol): any => { if (typeof key === 'string') { - let index = Number(key); + const index = Number(key); if (Number.isInteger(index)) { // key is an integer (i.e., index) return obj._getitem(index); @@ -106,16 +111,16 @@ export class Tensor { // @ts-ignore return obj[key]; }, - set: (obj, key, value) => { + set: (obj: Tensor, key: string | symbol, value: any): boolean => { // TODO allow setting of data - + // @ts-ignore return obj[key] = value; } }); } - dispose() { + dispose(): void { this.ort_tensor.dispose(); // this.ort_tensor = undefined; } @@ -125,7 +130,7 @@ export class Tensor { * If the tensor has more than one dimension, the iterator will yield subarrays. * @returns {Iterator} An iterator object for iterating over the tensor data in row-major order. */ - *[Symbol.iterator]() { + *[Symbol.iterator](): Iterator { const [iterLength, ...iterDims] = this.dims; if (iterDims.length > 0) { @@ -134,9 +139,8 @@ export class Tensor { yield this._subarray(i, iterSize, iterDims); } } else { - yield* this.data + yield* this.data; } - } /** @@ -144,7 +148,7 @@ export class Tensor { * @param {number} index The index to access. * @returns {Tensor} The data at the specified index. */ - _getitem(index) { + private _getitem(index: number): Tensor { const [iterLength, ...iterDims] = this.dims; index = safeIndex(index, iterLength); @@ -161,11 +165,10 @@ export class Tensor { * @param {number|bigint} item The item to search for in the tensor * @returns {number} The index of the first occurrence of item in the tensor data. */ - indexOf(item) { - const this_data = this.data; - for (let index = 0; index < this_data.length; ++index) { + indexOf(item: number | bigint): number { + for (let index = 0; index < this.data.length; ++index) { // Note: == instead of === so we can match Ints with BigInts - if (this_data[index] == item) { + if (this.data[index] == item) { return index; } } @@ -178,15 +181,14 @@ export class Tensor { * @param {any} iterDims * @returns {Tensor} */ - _subarray(index, iterSize, iterDims) { + private _subarray(index: number, iterSize: number, iterDims: number[]): Tensor { const o1 = index * iterSize; const o2 = (index + 1) * iterSize; // We use subarray if available (typed array), otherwise we use slice (normal array) - const data = - ('subarray' in this.data) - ? this.data.subarray(o1, o2) - : this.data.slice(o1, o2); + const data = 'subarray' in this.data + ? (this.data as AnyTypedArray).subarray(o1, o2) + : this.data.slice(o1, o2); return new Tensor(this.type, data, iterDims); } @@ -196,27 +198,26 @@ export class Tensor { * @returns {number|bigint} The value of this tensor as a standard JavaScript Number. * @throws {Error} If the tensor has more than one element. */ - item() { - const this_data = this.data; - if (this_data.length !== 1) { - throw new Error(`a Tensor with ${this_data.length} elements cannot be converted to Scalar`); + item(): number | bigint { + if (this.data.length !== 1) { + throw new Error(`a Tensor with ${this.data.length} elements cannot be converted to Scalar`); } - return this_data[0]; + return this.data[0]; } /** * Convert tensor data to a n-dimensional JS list * @returns {Array} */ - tolist() { - return reshape(this.data, this.dims) + tolist(): any[] { + return reshape(this.data, this.dims); } /** * Return a new Tensor with the sigmoid function applied to each element. * @returns {Tensor} The tensor with the sigmoid function applied. */ - sigmoid() { + sigmoid(): Tensor { return this.clone().sigmoid_(); } @@ -224,10 +225,9 @@ export class Tensor { * Applies the sigmoid function to the tensor in place. * @returns {Tensor} Returns `this`. */ - sigmoid_() { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] = 1 / (1 + Math.exp(-this_data[i])); + sigmoid_(): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] = 1 / (1 + Math.exp(-this.data[i])); } return this; } @@ -238,7 +238,7 @@ export class Tensor { * the current element, its index, and the tensor's data array. * @returns {Tensor} A new Tensor with the callback function applied to each element. */ - map(callback) { + map(callback: (value: number, index: number, array: DataArray) => number): Tensor { return this.clone().map_(callback); } @@ -248,10 +248,9 @@ export class Tensor { * the current element, its index, and the tensor's data array. * @returns {Tensor} Returns `this`. */ - map_(callback) { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] = callback(this_data[i], i, this_data); + map_(callback: (value: number, index: number, array: DataArray) => number): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] = callback(this.data[i], i, this.data); } return this; } @@ -261,7 +260,7 @@ export class Tensor { * @param {number} val The value to multiply by. * @returns {Tensor} The new tensor. */ - mul(val) { + mul(val: number): Tensor { return this.clone().mul_(val); } @@ -270,10 +269,9 @@ export class Tensor { * @param {number} val The value to multiply by. * @returns {Tensor} Returns `this`. */ - mul_(val) { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] *= val; + mul_(val: number): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] *= val; } return this; } @@ -283,7 +281,7 @@ export class Tensor { * @param {number} val The value to divide by. * @returns {Tensor} The new tensor. */ - div(val) { + div(val: number): Tensor { return this.clone().div_(val); } @@ -292,10 +290,9 @@ export class Tensor { * @param {number} val The value to divide by. * @returns {Tensor} Returns `this`. */ - div_(val) { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] /= val; + div_(val: number): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] /= val; } return this; } @@ -305,7 +302,7 @@ export class Tensor { * @param {number} val The value to add by. * @returns {Tensor} The new tensor. */ - add(val) { + add(val: number): Tensor { return this.clone().add_(val); } @@ -314,10 +311,9 @@ export class Tensor { * @param {number} val The value to add by. * @returns {Tensor} Returns `this`. */ - add_(val) { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] += val; + add_(val: number): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] += val; } return this; } @@ -327,7 +323,7 @@ export class Tensor { * @param {number} val The value to subtract by. * @returns {Tensor} The new tensor. */ - sub(val) { + sub(val: number): Tensor { return this.clone().sub_(val); } @@ -336,20 +332,22 @@ export class Tensor { * @param {number} val The value to subtract by. * @returns {Tensor} Returns `this`. */ - sub_(val) { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] -= val; + sub_(val: number): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] -= val; } return this; } /** * Creates a deep copy of the current Tensor. - * @returns {Tensor} A new Tensor with the same type, data, and dimensions as the original. */ - clone() { - return new Tensor(this.type, this.data.slice(), this.dims.slice()); + clone(): Tensor { + return new Tensor( + this.type, + this.data.slice(), + this.dims.slice() + ); } /** @@ -381,13 +379,10 @@ export class Tensor { * @returns {Tensor} A new Tensor containing the selected elements. * @throws {Error} If the slice input is invalid. */ - slice(...slices) { - // This allows for slicing with ranges and numbers - const newTensorDims = []; - const newOffsets = []; + slice(...slices: (number | [number, number] | null)[]): Tensor { + const newTensorDims: number[] = []; + const newOffsets: [number, number][] = []; - // slices is an array of numbers or arrays of numbers - // e.g., slices = [0, [1, 3], null, [0, 3]] for (let sliceIndex = 0; sliceIndex < this.dims.length; ++sliceIndex) { let slice = slices[sliceIndex]; @@ -398,7 +393,6 @@ export class Tensor { } else if (typeof slice === 'number') { slice = safeIndex(slice, this.dims[sliceIndex], sliceIndex); - // A number means take a single element newOffsets.push([slice, slice + 1]); @@ -416,7 +410,7 @@ export class Tensor { throw new Error(`Invalid slice: ${slice}`); } - const offsets = [ + const offsets: [number, number] = [ Math.max(start, 0), Math.min(end, this.dims[sliceIndex]) ]; @@ -432,10 +426,8 @@ export class Tensor { const newDims = newOffsets.map(([start, end]) => end - start); const newBufferSize = newDims.reduce((a, b) => a * b); - const this_data = this.data; // Allocate memory - // @ts-ignore - const data = new this_data.constructor(newBufferSize); + const data = new (this.data.constructor as new (length: number) => DataArray)(newBufferSize); // Precompute strides const stride = this.stride(); @@ -447,7 +439,7 @@ export class Tensor { originalIndex += ((num % size) + newOffsets[j][0]) * stride[j]; num = Math.floor(num / size); } - data[i] = this_data[originalIndex]; + data[i] = this.data[originalIndex]; } return new Tensor(this.type, data, newTensorDims); } @@ -457,12 +449,15 @@ export class Tensor { * @param {...number} dims Dimensions to permute. * @returns {Tensor} The permuted tensor. */ - permute(...dims) { + permute(...dims: number[]): Tensor { return permute(this, dims); } - // TODO: implement transpose. For now (backwards compatibility), it's just an alias for permute() - transpose(...dims) { + /** + * Alias for permute() + * TODO: implement transpose. For now (backwards compatibility), it's just an alias for permute() + */ + transpose(...dims: number[]): Tensor { return this.permute(...dims); } @@ -473,7 +468,7 @@ export class Tensor { * @param {boolean} keepdim Whether the output tensor has `dim` retained or not. * @returns The summed tensor */ - sum(dim = null, keepdim = false) { + sum(dim: number | null = null, keepdim: boolean = false): Tensor { return this.norm(1, dim, keepdim); } @@ -485,7 +480,7 @@ export class Tensor { * @param {boolean} [keepdim=false] Whether the output tensors have dim retained or not. * @returns {Tensor} The norm of the tensor. */ - norm(p = 'fro', dim = null, keepdim = false) { + norm(p: number | string = 'fro', dim: number | null = null, keepdim: boolean = false): Tensor { if (p === 'fro') { // NOTE: Since we only support integer dims, Frobenius norm produces the same result as p=2. p = 2; @@ -493,12 +488,11 @@ export class Tensor { throw Error(`Unsupported norm: ${p}`); } - const this_data = this.data; const fn = (a, b) => a + (b ** p); if (dim === null) { // @ts-ignore - const val = this_data.reduce(fn, 0) ** (1 / p); + const val = this.data.reduce(fn, 0) ** (1 / p); return new Tensor(this.type, [val], []); } @@ -518,14 +512,13 @@ export class Tensor { * @param {number} [dim=1] The dimension to reduce * @returns {Tensor} `this` for operation chaining. */ - normalize_(p = 2.0, dim = 1) { + normalize_(p: number = 2.0, dim: number = 1): Tensor { dim = safeIndex(dim, this.dims.length); const norm = this.norm(p, dim, true); - const this_data = this.data; const norm_data = norm.data; - for (let i = 0; i < this_data.length; ++i) { + for (let i = 0; i < this.data.length; ++i) { // Calculate the index in the resulting array let resultIndex = 0; @@ -541,7 +534,7 @@ export class Tensor { } // Divide by normalized value - this_data[i] /= norm_data[resultIndex]; + this.data[i] /= norm_data[resultIndex]; } return this; @@ -553,7 +546,7 @@ export class Tensor { * @param {number} [dim=1] The dimension to reduce * @returns {Tensor} The normalized tensor. */ - normalize(p = 2.0, dim = 1) { + normalize(p: number = 2.0, dim: number = 1): Tensor { return this.clone().normalize_(p, dim); } @@ -562,7 +555,7 @@ export class Tensor { * Stride is the jump necessary to go from one element to the next one in the specified dimension dim. * @returns {number[]} The stride of this tensor. */ - stride() { + stride(): number[] { return dimsToStride(this.dims); } @@ -572,21 +565,21 @@ export class Tensor { * NOTE: The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other. * If you would like a copy, use `tensor.clone()` before squeezing. * - * @param {number|number[]} [dim=null] If given, the input will be squeezed only in the specified dimensions. + * @param {number|number[]|null} [dim=null] If given, the input will be squeezed only in the specified dimensions. * @returns {Tensor} The squeezed tensor */ - squeeze(dim = null) { + squeeze(dim: number | number[] | null = null): Tensor { return new Tensor( this.type, this.data, calc_squeeze_dims(this.dims, dim) - ) + ); } /** - * In-place version of @see {@link Tensor.squeeze} + * In-place version of squeeze */ - squeeze_(dim = null) { + squeeze_(dim: number | number[] | null = null): Tensor { this.dims = calc_squeeze_dims(this.dims, dim); return this; } @@ -599,7 +592,7 @@ export class Tensor { * @param {number} dim The index at which to insert the singleton dimension * @returns {Tensor} The unsqueezed tensor */ - unsqueeze(dim = null) { + unsqueeze(dim: number): Tensor { return new Tensor( this.type, this.data, @@ -608,25 +601,24 @@ export class Tensor { } /** - * In-place version of @see {@link Tensor.unsqueeze} + * In-place version of unsqueeze */ - unsqueeze_(dim = null) { + unsqueeze_(dim: number): Tensor { this.dims = calc_unsqueeze_dims(this.dims, dim); return this; } /** - * In-place version of @see {@link Tensor.flatten} + * In-place version of flatten */ - flatten_(start_dim = 0, end_dim = -1) { - // TODO validate inputs + flatten_(start_dim: number = 0, end_dim: number = -1): Tensor { end_dim = (end_dim + this.dims.length) % this.dims.length; - let dimsToKeepBefore = this.dims.slice(0, start_dim); - let dimsToFlatten = this.dims.slice(start_dim, end_dim + 1); - let dimsToKeepAfter = this.dims.slice(end_dim + 1); + const dimsToKeepBefore = this.dims.slice(0, start_dim); + const dimsToFlatten = this.dims.slice(start_dim, end_dim + 1); + const dimsToKeepAfter = this.dims.slice(end_dim + 1); - this.dims = [...dimsToKeepBefore, dimsToFlatten.reduce((a, b) => a * b, 1), ...dimsToKeepAfter] + this.dims = [...dimsToKeepBefore, dimsToFlatten.reduce((a, b) => a * b, 1), ...dimsToKeepAfter]; return this; } @@ -638,7 +630,7 @@ export class Tensor { * @param {number} end_dim the last dim to flatten * @returns {Tensor} The flattened tensor. */ - flatten(start_dim = 0, end_dim = -1) { + flatten(start_dim: number = 0, end_dim: number = -1): Tensor { return this.clone().flatten_(start_dim, end_dim); } @@ -647,8 +639,7 @@ export class Tensor { * @param {...number} dims the desired size * @returns {Tensor} The tensor with the same data but different shape */ - view(...dims) { - // TODO: validate dims + view(...dims: number[]): Tensor { let inferredIndex = -1; for (let i = 0; i < dims.length; ++i) { if (dims[i] === -1) { @@ -659,26 +650,31 @@ export class Tensor { } } - const this_data = this.data; if (inferredIndex !== -1) { // Some dimension must be inferred const productOther = dims.reduce((product, curr, index) => { - return index !== inferredIndex ? product * curr : product + return index !== inferredIndex ? product * curr : product; }, 1); - dims[inferredIndex] = this_data.length / productOther; + dims[inferredIndex] = this.data.length / productOther; } - return new Tensor(this.type, this_data, dims); // NOTE: uses same underlying storage + return new Tensor(this.type, this.data, dims); // NOTE: uses same underlying storage } - neg_() { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] = -this_data[i]; + /** + * Negate the tensor in place + */ + neg_(): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] = -this.data[i]; } return this; } - neg() { + + /** + * Return a new negated tensor + */ + neg(): Tensor { return this.clone().neg_(); } @@ -687,11 +683,10 @@ export class Tensor { * @param {number} val The value to compare with. * @returns {Tensor} A boolean tensor that is `true` where input is greater than other and `false` elsewhere. */ - gt(val) { + gt(val: number): Tensor { const mask = new Uint8Array(this.data.length); - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - mask[i] = this_data[i] > val ? 1 : 0; + for (let i = 0; i < this.data.length; ++i) { + mask[i] = this.data[i] > val ? 1 : 0; } return new Tensor('bool', mask, this.dims); } @@ -701,22 +696,20 @@ export class Tensor { * @param {number} val The value to compare with. * @returns {Tensor} A boolean tensor that is `true` where input is less than other and `false` elsewhere. */ - lt(val) { + lt(val: number): Tensor { const mask = new Uint8Array(this.data.length); - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - mask[i] = this_data[i] < val ? 1 : 0; + for (let i = 0; i < this.data.length; ++i) { + mask[i] = this.data[i] < val ? 1 : 0; } return new Tensor('bool', mask, this.dims); } /** - * In-place version of @see {@link Tensor.clamp} + * In-place version of clamp */ - clamp_(min, max) { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] = Math.min(Math.max(this_data[i], min), max); + clamp_(min: number, max: number): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] = Math.min(Math.max(this.data[i], min), max); } return this; } @@ -727,17 +720,16 @@ export class Tensor { * @param {number} max upper-bound of the range to be clamped to * @returns {Tensor} the output tensor. */ - clamp(min, max) { + clamp(min: number, max: number): Tensor { return this.clone().clamp_(min, max); } /** - * In-place version of @see {@link Tensor.round} + * In-place version of round */ - round_() { - const this_data = this.data; - for (let i = 0; i < this_data.length; ++i) { - this_data[i] = Math.round(this_data[i]); + round_(): Tensor { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] = Math.round(this.data[i]); } return this; } @@ -746,42 +738,79 @@ export class Tensor { * Rounds elements of input to the nearest integer. * @returns {Tensor} the output tensor. */ - round() { + round(): Tensor { return this.clone().round_(); } - mean(dim = null, keepdim = false) { + /** + * Returns the mean value of each row of the input tensor in the given dimension. + */ + mean(dim: number | null = null, keepdim = false): Tensor { return mean(this, dim, keepdim); } - min(dim = null, keepdim = false) { + /** + * Returns the minimum value of the tensor. + */ + min(dim: number | null = null, keepdim = false): Tensor { if (dim === null) { // None to reduce over all dimensions. const val = min(this.data)[0]; return new Tensor(this.type, [val], [/* scalar */]); } - const [type, result, resultDims] = reduce_helper((a, b) => Math.min(a, b), this, dim, keepdim, Infinity); + const [type, result, resultDims] = reduce_helper( + (a: number, b: number) => Math.min(a, b), + this, + dim, + keepdim, + Infinity + ); return new Tensor(type, result, resultDims); } - max(dim = null, keepdim = false) { + /** + * Returns the maximum value of the tensor. + * @param {number} [dim=null] The dimension or dimensions to reduce. If `null`, all dimensions are reduced. + * @param {boolean} keepdim Whether the output tensors have `dim` retained or not. + * @returns {Tensor} The maximum value of the tensor. + */ + max(dim: number | null = null, keepdim: boolean = false): Tensor { if (dim === null) { // None to reduce over all dimensions. const val = max(this.data)[0]; return new Tensor(this.type, [val], [/* scalar */]); } - const [type, result, resultDims] = reduce_helper((a, b) => Math.max(a, b), this, dim, keepdim, -Infinity); + const [type, result, resultDims] = reduce_helper( + (a: number, b: number) => Math.max(a, b), + this, + dim, + keepdim, + -Infinity + ); return new Tensor(type, result, resultDims); } - argmin(dim = null, keepdim = false) { + /** + * Returns the indices of the minimum values along a dimension. + * @param {number} [dim=null] The dimension or dimensions to reduce. If `null`, all dimensions are reduced. + * @param {boolean} keepdim Whether the output tensors have `dim` retained or not. + * @returns {Tensor} The indices of the minimum values along a dimension. + */ + argmin(dim: number | null = null, keepdim: boolean = false): Tensor { if (dim !== null) { throw new Error("`dim !== null` not yet implemented."); } const index = min(this.data)[1]; return new Tensor('int64', [BigInt(index)], []); } - argmax(dim = null, keepdim = false) { + + /** + * Returns the indices of the maximum values along a dimension. + * @param {number} [dim=null] The dimension or dimensions to reduce. If `null`, all dimensions are reduced. + * @param {boolean} keepdim Whether the output tensors have `dim` retained or not. + * @returns {Tensor} The indices of the maximum values along a dimension. + */ + argmax(dim: number | null = null, keepdim: boolean = false): Tensor { if (dim !== null) { throw new Error("`dim !== null` not yet implemented."); } @@ -794,7 +823,7 @@ export class Tensor { * @param {DataType} type The desired data type. * @returns {Tensor} The converted tensor. */ - to(type) { + to(type: DataType): Tensor { // If the self Tensor already has the correct dtype, then self is returned. if (this.type === type) return this; @@ -824,33 +853,38 @@ export class Tensor { * This creates a nested array of a given type and depth (see examples). * * @example - * NestArray; // string[] - * @example - * NestArray; // number[][] - * @example - * NestArray; // string[][][] etc. - * @template T - * @template {number} Depth - * @template {never[]} [Acc=[]] - * @typedef {Acc['length'] extends Depth ? T : NestArray} NestArray - */ +* NestArray; // string[] +* @example +* NestArray; // number[][] +* @example +* NestArray; // string[][][] etc. +* @template T +* @template {number} Depth +* @template {never[]} [Acc=[]] +* @typedef {Acc['length'] extends Depth ? T : NestArray} NestArray +*/ + +type NestArray = + Acc['length'] extends Depth + ? T + : NestArray; + /** * Reshapes a 1-dimensional array into an n-dimensional array, according to the provided dimensions. * * @example - * reshape([10 ], [1 ]); // Type: number[] Value: [10] - * reshape([1, 2, 3, 4 ], [2, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4]] - * reshape([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); // Type: number[][][] Value: [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] - * reshape([1, 2, 3, 4, 5, 6, 7, 8], [4, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4], [5, 6], [7, 8]] - * @param {T[]|DataArray} data The input array to reshape. - * @param {DIM} dimensions The target shape/dimensions. - * @template T - * @template {[number]|number[]} DIM - * @returns {NestArray} The reshaped array. - */ -function reshape(data, dimensions) { - +* reshape([10 ], [1 ]); // Type: number[] Value: [10] +* reshape([1, 2, 3, 4 ], [2, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4]] +* reshape([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); // Type: number[][][] Value: [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] +* reshape([1, 2, 3, 4, 5, 6, 7, 8], [4, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4], [5, 6], [7, 8]] +* @param {T[]|DataArray} data The input array to reshape. +* @param {DIM} dimensions The target shape/dimensions. +* @template T +* @template {[number]|number[]} DIM +* @returns {NestArray} The reshaped array. +*/ +function reshape(data: ArrayLike, dimensions: number[]): NestArray { const totalElements = data.length; const dimensionSize = dimensions.reduce((a, b) => a * b); @@ -858,24 +892,17 @@ function reshape(data, dimensions) { throw Error(`cannot reshape array of size ${totalElements} into shape (${dimensions})`); } - /** @type {any} */ - let reshapedArray = data; + let reshapedArray: any = Array.from(data); for (let i = dimensions.length - 1; i >= 0; i--) { - reshapedArray = reshapedArray.reduce((acc, val) => { - let lastArray = acc[acc.length - 1]; - - if (lastArray.length < dimensions[i]) { - lastArray.push(val); - } else { - acc.push([val]); - } - - return acc; - }, [[]]); + const temp: any[] = []; + for (let j = 0; j < reshapedArray.length; j += dimensions[i]) { + temp.push(reshapedArray.slice(j, j + dimensions[i])); + } + reshapedArray = temp; } - return reshapedArray[0]; + return reshapedArray; } /** @@ -884,12 +911,11 @@ function reshape(data, dimensions) { * @param {Array} axes The axes to permute the tensor along. * @returns {Tensor} The permuted tensor. */ -export function permute(tensor, axes) { - const [permutedData, shape] = permute_data(tensor.data, tensor.dims, axes); +export function permute(tensor: Tensor, axes: number[]): Tensor { + const [permutedData, shape] = permute_data(tensor.data as TypedArray, tensor.dims, axes); return new Tensor(tensor.type, permutedData, shape); } - /** * Interpolates an Tensor to the given size. * @param {Tensor} input The input tensor to interpolate. Data must be channel-first (i.e., [c, h, w]) @@ -898,15 +924,19 @@ export function permute(tensor, axes) { * @param {boolean} align_corners Whether to align corners. * @returns {Tensor} The interpolated tensor. */ -export function interpolate(input, [out_height, out_width], mode = 'bilinear', align_corners = false) { - +export function interpolate( + input: Tensor, + [out_height, out_width]: [number, number], + mode: string = 'bilinear', + align_corners: boolean = false +): Tensor { // Input image dimensions const in_channels = input.dims.at(-3) ?? 1; const in_height = input.dims.at(-2); const in_width = input.dims.at(-1); let output = interpolate_data( - /** @type {import('./maths.js').TypedArray}*/(input.data), + input.data as TypedArray, [in_channels, in_height, in_width], [out_height, out_width], mode, @@ -915,7 +945,6 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a return new Tensor(input.type, output, [in_channels, out_height, out_width]); } - /** * Down/up samples the input. * Inspired by https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. @@ -925,11 +954,13 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a * @param {"nearest"|"bilinear"|"bicubic"} [options.mode='bilinear'] algorithm used for upsampling * @returns {Promise} The interpolated tensor. */ -export async function interpolate_4d(input, { +export async function interpolate_4d(input: Tensor, { size = null, mode = 'bilinear', -} = {}) { - +}: { + size?: [number, number] | [number, number, number] | [number, number, number, number] | null, + mode?: 'nearest' | 'bilinear' | 'bicubic' +} = {}): Promise { // Error checking if (input.dims.length !== 4) { throw new Error('`interpolate_4d` currently only supports 4D input.'); @@ -940,7 +971,7 @@ export async function interpolate_4d(input, { } // Fill in missing dimensions - let targetDims; + let targetDims: number[]; if (size.length === 2) { targetDims = [...input.dims.slice(0, 2), ...size]; } else if (size.length === 3) { @@ -973,7 +1004,7 @@ export async function interpolate_4d(input, { * @param {Tensor} b the second tensor to be multiplied * @returns {Promise} The matrix product of the two tensors. */ -export async function matmul(a, b) { +export async function matmul(a: Tensor, b: Tensor): Promise { const op = await TensorOpRegistry.matmul; return await op({ a, b }); } @@ -985,12 +1016,11 @@ export async function matmul(a, b) { * @param {Tensor} a The dimension along which to take the one dimensional real FFT. * @returns {Promise} the output tensor. */ -export async function rfft(x, a) { +export async function rfft(x: Tensor, a: Tensor): Promise { const op = await TensorOpRegistry.rfft; return await op({ x, a }); } - /** * Returns the k largest elements of the given input tensor. * Inspired by https://pytorch.org/docs/stable/generated/torch.topk.html @@ -998,7 +1028,7 @@ export async function rfft(x, a) { * @param {number} [k] the k in "top-k" * @returns {Promise<[Tensor, Tensor]>} the output tuple of (Tensor, LongTensor) of top-k elements and their indices. */ -export async function topk(x, k) { +export async function topk(x: Tensor, k?: number): Promise<[Tensor, Tensor]> { const op = await TensorOpRegistry.top_k; if (k == null) { @@ -1016,8 +1046,13 @@ export async function topk(x, k) { }); } +/** + * Helper function to convert array to index tensor + */ +function arrayToIndexTensor(array: number[]): Tensor { + return new Tensor('int64', array.map(BigInt), [array.length]); +} -const arrayToIndexTensor = (array) => new Tensor('int64', array, [array.length]); /** * Slice a multidimensional float32 tensor. * @param {Tensor} data: Tensor of data to extract slices from @@ -1027,7 +1062,13 @@ const arrayToIndexTensor = (array) => new Tensor('int64', array, [array.length]) * @param {number[]} [steps]: 1-D array of slice step of corresponding axis in axes. * @returns {Promise} Sliced data tensor. */ -export async function slice(data, starts, ends, axes, steps) { +export async function slice( + data: Tensor, + starts: number[], + ends: number[], + axes: number[], + steps?: number[] +): Promise { const op = await TensorOpRegistry.slice; return await op({ x: data, @@ -1038,23 +1079,21 @@ export async function slice(data, starts, ends, axes, steps) { }); } - /** * Perform mean pooling of the last hidden state followed by a normalization step. * @param {Tensor} last_hidden_state Tensor of shape [batchSize, seqLength, embedDim] * @param {Tensor} attention_mask Tensor of shape [batchSize, seqLength] * @returns {Tensor} Returns a new Tensor of shape [batchSize, embedDim]. */ -export function mean_pooling(last_hidden_state, attention_mask) { +export function mean_pooling(last_hidden_state: Tensor, attention_mask: Tensor): Tensor { // last_hidden_state: [batchSize, seqLength, embedDim] // attention_mask: [batchSize, seqLength] - const lastHiddenStateData = last_hidden_state.data; + const lastHiddenStateData = last_hidden_state.data as number[]; const attentionMaskData = attention_mask.data; const shape = [last_hidden_state.dims[0], last_hidden_state.dims[2]]; - // @ts-ignore - const returnedData = new lastHiddenStateData.constructor(shape[0] * shape[1]); + const returnedData = new (lastHiddenStateData.constructor as new (length: number) => typeof lastHiddenStateData)(shape[0] * shape[1]); const [batchSize, seqLength, embedDim] = last_hidden_state.dims; let outIndex = 0; @@ -1085,7 +1124,7 @@ export function mean_pooling(last_hidden_state, attention_mask) { last_hidden_state.type, returnedData, shape - ) + ); } /** @@ -1096,9 +1135,15 @@ export function mean_pooling(last_hidden_state, attention_mask) { * @param {number} [options.eps=1e-5] A value added to the denominator for numerical stability. * @returns {Tensor} The normalized tensor. */ -export function layer_norm(input, normalized_shape, { - eps = 1e-5, -} = {}) { +export function layer_norm( + input: Tensor, + normalized_shape: number[], + { + eps = 1e-5, + }: { + eps?: number + } = {} +): Tensor { if (input.dims.length !== 2) { throw new Error('`layer_norm` currently only supports 2D input.'); } @@ -1110,13 +1155,12 @@ export function layer_norm(input, normalized_shape, { } const [std, mean] = std_mean(input, 1, 0, true); - const stdData = /** @type {Float32Array} */(std.data); - const meanData = /** @type {Float32Array} */(mean.data); + const stdData = std.data as Float32Array; + const meanData = mean.data as Float32Array; - const inputData = /** @type {Float32Array} */(input.data); + const inputData = input.data as Float32Array; - // @ts-ignore - const returnedData = new inputData.constructor(inputData.length); + const returnedData = new (inputData.constructor as new (length: number) => typeof inputData)(inputData.length); for (let i = 0; i < batchSize; ++i) { const offset = i * featureDim; @@ -1126,7 +1170,7 @@ export function layer_norm(input, normalized_shape, { } } return new Tensor(input.type, returnedData, input.dims); -} +} /** * Helper function to calculate new dimensions when performing a squeeze operation. @@ -1135,7 +1179,7 @@ export function layer_norm(input, normalized_shape, { * @returns {number[]} The new dimensions. * @private */ -function calc_squeeze_dims(dims, dim) { +function calc_squeeze_dims(dims: number[], dim: number | number[] | null): number[] { dims = dims.slice(); if (dim === null) { dims = dims.filter((d) => d !== 1); @@ -1158,7 +1202,7 @@ function calc_squeeze_dims(dims, dim) { * @returns {number[]} The new dimensions. * @private */ -function calc_unsqueeze_dims(dims, dim) { +function calc_unsqueeze_dims(dims: number[], dim: number): number[] { // Dimension out of range (e.g., "expected to be in range of [-4, 3], but got 4") // + 1 since we allow inserting at the end (i.e. dim = -1) dim = safeIndex(dim, dims.length + 1); @@ -1178,7 +1222,7 @@ function calc_unsqueeze_dims(dims, dim) { * @throws {Error} If the index is out of range. * @private */ -function safeIndex(index, size, dimension = null, boundsCheck = true) { +function safeIndex(index: number, size: number, dimension: number = null, boundsCheck = true): number { if (boundsCheck && (index < -size || index >= size)) { throw new Error(`IndexError: index ${index} is out of bounds for dimension${dimension === null ? '' : ' ' + dimension} with size ${size}`); } @@ -1196,7 +1240,7 @@ function safeIndex(index, size, dimension = null, boundsCheck = true) { * @param {number} dim The dimension to concatenate along. * @returns {Tensor} The concatenated tensor. */ -export function cat(tensors, dim = 0) { +export function cat(tensors: Tensor[], dim: number = 0): Tensor { dim = safeIndex(dim, tensors[0].dims.length); // TODO do validation of shapes @@ -1214,16 +1258,13 @@ export function cat(tensors, dim = 0) { if (dim === 0) { // Handle special case for performance reasons - let offset = 0; for (const tensor of tensors) { const tensorData = tensor.data; result.set(tensorData, offset); offset += tensorData.length; } - } else { - let currentDim = 0; for (let t = 0; t < tensors.length; ++t) { @@ -1260,41 +1301,48 @@ export function cat(tensors, dim = 0) { * @param {number} dim The dimension to stack along. * @returns {Tensor} The stacked tensor. */ -export function stack(tensors, dim = 0) { +export function stack(tensors: Tensor[], dim: number = 0): Tensor { // TODO do validation of shapes // NOTE: stack expects each tensor to be equal size return cat(tensors.map(t => t.unsqueeze(dim)), dim); } - /** * @param {(previousValue: any, currentValue: any, currentIndex?: number, resultIndex?: number) => any} callbackfn * @param {Tensor} input the input tensor. * @param {number|null} dim the dimension to reduce. * @param {boolean} keepdim whether the output tensor has dim retained or not. - * @returns {[DataType, any, number[]]} The reduced tensor data. + * @returns {[DataType, DataArray, number[]]} The reduced tensor data. */ -function reduce_helper(callbackfn, input, dim = null, keepdim = false, initialValue = null) { +function reduce_helper( + callbackfn: (previousValue: any, currentValue: any, currentIndex?: number, resultIndex?: number) => any, + input: Tensor, + dim: number | null = null, + keepdim: boolean = false, + initialValue: number | bigint | null = null +): [DataType, DataArray, number[]] { const inputData = input.data; const inputDims = input.dims; // Negative indexing - dim = safeIndex(dim, inputDims.length); + dim = dim !== null ? safeIndex(dim, inputDims.length) : null; // Calculate the shape of the resulting array after summation const resultDims = inputDims.slice(); // Copy the original dimensions - resultDims[dim] = 1; // Remove the specified axis + if (dim !== null) { + resultDims[dim] = 1; // Remove the specified axis + } // Create a new array to store the accumulated values - // @ts-ignore - const result = new inputData.constructor(inputData.length / inputDims[dim]); + const result = new (inputData.constructor as new (length: number) => typeof inputData)( + inputData.length / (dim !== null ? inputDims[dim] : 1) + ); if (initialValue !== null) { result.fill(initialValue); } // Iterate over the data array for (let i = 0; i < inputData.length; ++i) { - // Calculate the index in the resulting array let resultIndex = 0; @@ -1312,12 +1360,13 @@ function reduce_helper(callbackfn, input, dim = null, keepdim = false, initialVa result[resultIndex] = callbackfn(result[resultIndex], inputData[i], i, resultIndex); } - if (!keepdim) resultDims.splice(dim, 1); + if (!keepdim && dim !== null) resultDims.splice(dim, 1); return [input.type, result, resultDims]; } +/** /** * Calculates the standard deviation and mean over the dimensions specified by dim. dim can be a single dimension or `null` to reduce over all dimensions. * @param {Tensor} input the input tenso @@ -1326,8 +1375,13 @@ function reduce_helper(callbackfn, input, dim = null, keepdim = false, initialVa * @param {boolean} keepdim whether the output tensor has dim retained or not. * @returns {Tensor[]} A tuple of (std, mean) tensors. */ -export function std_mean(input, dim = null, correction = 1, keepdim = false) { - const inputData = /** @type {Float32Array} */(input.data); +export function std_mean( + input: Tensor, + dim: number | null = null, + correction: number = 1, + keepdim: boolean = false +): [Tensor, Tensor] { + const inputData = input.data as Float32Array; const inputDims = input.dims; if (dim === null) { @@ -1343,10 +1397,15 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) { } dim = safeIndex(dim, inputDims.length); const meanTensor = mean(input, dim, keepdim); - const meanTensorData = meanTensor.data; + const meanTensorData = meanTensor.data as number[]; // Compute squared sum - const [type, result, resultDims] = reduce_helper((a, b, i, j) => a + (b - meanTensorData[j]) ** 2, input, dim, keepdim); + const [type, result, resultDims] = reduce_helper( + (a: number, b: number, i: number, j: number) => a + (b - meanTensorData[j]) ** 2, + input, + dim, + keepdim + ); // Square root of the squared sum for (let i = 0; i < result.length; ++i) { @@ -1365,9 +1424,9 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) { * @param {boolean} keepdim whether the output tensor has dim retained or not. * @returns {Tensor} A new tensor with means taken along the specified dimension. */ -export function mean(input, dim = null, keepdim = false) { +export function mean(input: Tensor, dim: number | null = null, keepdim: boolean = false): Tensor { const inputDims = input.dims; - const inputData = /** @type {Float32Array} */(input.data); + const inputData = input.data as number[]; if (dim === null) { // None to reduce over all dimensions. @@ -1377,7 +1436,7 @@ export function mean(input, dim = null, keepdim = false) { dim = safeIndex(dim, inputDims.length); // Compute sum - const [type, result, resultDims] = reduce_helper((a, b) => a + b, input, dim, keepdim); + const [type, result, resultDims] = reduce_helper((a: number, b: number) => a + b, input, dim, keepdim); // Divide by number of elements in the dimension if (inputDims[dim] !== 1) { @@ -1389,8 +1448,10 @@ export function mean(input, dim = null, keepdim = false) { return new Tensor(type, result, resultDims); } - -function dimsToStride(dims) { +/** + * Helper function to calculate strides from dimensions + */ +function dimsToStride(dims: number[]): number[] { const stride = new Array(dims.length); for (let i = dims.length - 1, s2 = 1; i >= 0; --i) { stride[i] = s2; @@ -1399,13 +1460,21 @@ function dimsToStride(dims) { return stride; } -function fullHelper(size, fill_value, dtype, cls) { +/** + * Helper function for creating filled tensors + */ +function fullHelper( + size: number[], + fill_value: T, + dtype: DataType, + cls: new (length: number) => { fill(value: T): any } +): Tensor { const numElements = size.reduce((a, b) => a * b, 1); return new Tensor( dtype, new cls(numElements).fill(fill_value), size - ) + ); } /** @@ -1414,9 +1483,10 @@ function fullHelper(size, fill_value, dtype, cls) { * @param {number|bigint|boolean} fill_value The value to fill the output tensor with. * @returns {Tensor} The filled tensor. */ -export function full(size, fill_value) { - let dtype; - let typedArrayCls; +export function full(size: number[], fill_value: number | bigint | boolean): Tensor { + let dtype: DataType; + let typedArrayCls: new (length: number) => { fill(value: any): any }; + if (typeof fill_value === 'number') { dtype = 'float32'; typedArrayCls = Float32Array; @@ -1433,7 +1503,10 @@ export function full(size, fill_value) { return fullHelper(size, fill_value, dtype, typedArrayCls); } -export function full_like(tensor, fill_value) { +/** + * Returns a tensor filled with a scalar value, with the same size as input. + */ +export function full_like(tensor: Tensor, fill_value: number | bigint | boolean): Tensor { return full(tensor.dims, fill_value); } @@ -1442,7 +1515,7 @@ export function full_like(tensor, fill_value) { * @param {number[]} size A sequence of integers defining the shape of the output tensor. * @returns {Tensor} The ones tensor. */ -export function ones(size) { +export function ones(size: number[]): Tensor { return fullHelper(size, 1n, 'int64', BigInt64Array); } @@ -1451,7 +1524,7 @@ export function ones(size) { * @param {Tensor} tensor The size of input will determine size of the output tensor. * @returns {Tensor} The ones tensor. */ -export function ones_like(tensor) { +export function ones_like(tensor: Tensor): Tensor { return ones(tensor.dims); } @@ -1460,7 +1533,7 @@ export function ones_like(tensor) { * @param {number[]} size A sequence of integers defining the shape of the output tensor. * @returns {Tensor} The zeros tensor. */ -export function zeros(size) { +export function zeros(size: number[]): Tensor { return fullHelper(size, 0n, 'int64', BigInt64Array); } @@ -1469,7 +1542,7 @@ export function zeros(size) { * @param {Tensor} tensor The size of input will determine size of the output tensor. * @returns {Tensor} The zeros tensor. */ -export function zeros_like(tensor) { +export function zeros_like(tensor: Tensor): Tensor { return zeros(tensor.dims); } @@ -1478,13 +1551,13 @@ export function zeros_like(tensor) { * @param {number[]} size A sequence of integers defining the shape of the output tensor. * @returns {Tensor} The random tensor. */ -export function rand(size) { +export function rand(size: number[]): Tensor { const length = size.reduce((a, b) => a * b, 1); return new Tensor( "float32", Float32Array.from({ length }, () => Math.random()), size, - ) + ); } /** @@ -1493,23 +1566,20 @@ export function rand(size) { * @param {'binary'|'ubinary'} precision The precision to use for quantization. * @returns {Tensor} The quantized tensor. */ -export function quantize_embeddings(tensor, precision) { +export function quantize_embeddings(tensor: Tensor, precision: 'binary' | 'ubinary'): Tensor { if (tensor.dims.length !== 2) { throw new Error("The tensor must have 2 dimensions"); } if (tensor.dims.at(-1) % 8 !== 0) { throw new Error("The last dimension of the tensor must be a multiple of 8"); } - if (!['binary', 'ubinary'].includes(precision)) { - throw new Error("The precision must be either 'binary' or 'ubinary'"); - } const signed = precision === 'binary'; const dtype = signed ? 'int8' : 'uint8'; // Create a typed array to store the packed bits const cls = signed ? Int8Array : Uint8Array; - const inputData = tensor.data; + const inputData = tensor.data as number[]; const outputData = new cls(inputData.length / 8); // Iterate over each number in the array @@ -1526,7 +1596,8 @@ export function quantize_embeddings(tensor, precision) { if (signed && bitPosition === 0) { outputData[arrayIndex] -= 128; } - }; + } return new Tensor(dtype, outputData, [tensor.dims[0], tensor.dims[1] / 8]); } + diff --git a/tsconfig.json b/tsconfig.json index fb6de7097..40df2fd81 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -5,8 +5,8 @@ // Tells the compiler to check JS files "checkJs": true, "target": "esnext", - "module": "nodenext", - "moduleResolution": "nodenext", + "module": "esnext", + "moduleResolution": "node", "outDir": "types", "strict": false, "skipLibCheck": true,