diff --git a/src/utils/tensor.js b/src/utils/tensor.js index caecac814..4146067c4 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -41,8 +41,13 @@ export const DataTypeMap = Object.freeze({ }); /** - * @typedef {keyof typeof DataTypeMap} DataType - * @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray + * @typedef {keyof typeof DataTypeMap} DataType A Tensor data type, for example `uint8`. + * @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray A typed array or an array of values. + * @typedef {{x: number, y: number}} Point An object representing coordinates. + * @typedef {'RECT' | 'CROSS' | 'ELLIPSE'} Shape A shape for morphological operations. + * @typedef {{width: number, height: number}} Size An object representing the size of an object. + * + * @typedef {number | [number, number] | Size} KernelSize A kernel size for morphological operations. */ @@ -791,6 +796,189 @@ export class Tensor { return new Tensor('int64', [BigInt(index)], []); } + /** + * Mutates the data through a dilation morphological operation. + * + * @param {KernelSize} kernelSize The width and height of the kernel. + * @param {Shape} [shape='RECT'] The shape of the kernel. + * @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. + * @returns {Promise} Returns `this`. + */ + async dilate_(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) { + const this_data = this.data; + const data = await this.morphologicalOperation('DILATE', this_data, kernelSize, shape, anchor); + for (let i = 0; i < this_data.length; ++i) { + this.data[i] = data[i]; + } + return this; + } + + /** + * Returns a new Tensor where the data is mutated through a dilation + * morphological operation. + * + * @param {KernelSize} kernelSize The width and height of the kernel. + * @param {Shape} [shape='RECT'] The shape of the kernel. + * @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. + * @returns {Promise} The new Tensor. 
+ */ + async dilate(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) { + return this.clone().dilate_(kernelSize, shape, anchor); + } + + /** + * * Mutates the data through a erosion morphological operation. + * + * @param {KernelSize} kernelSize The width and height of the kernel. + * @param {Shape} [shape='RECT'] The shape of the kernel. + * @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. + * @returns {Promise} Returns `this`. + */ + async erode_(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) { + const this_data = this.data; + const data = await this.morphologicalOperation('ERODE', this_data, kernelSize, shape, anchor); + for (let i = 0; i < this_data.length; ++i) { + this.data[i] = data[i]; + } + return this; + } + + /** + * Returns a new Tensor where the data is mutated through a erosion + * morphological operation. + * + * @param {KernelSize} kernelSize The width and height of the kernel. + * @param {Shape} [shape='RECT'] The shape of the kernel. + * @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. + * @returns {Promise} The new Tensor. + */ + async erode(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) { + return this.clone().erode_(kernelSize, shape, anchor); + } + + /** + * Applies a morphological operation to this tensor. + * + * @param {'DILATE' | 'ERODE'} operation The operation to apply. + * @param {DataArray} data The input tensor data. + * @param {KernelSize} kernelSize The width and height of the kernel. + * @param {Shape} [shape='RECT'] The shape of the kernel. + * @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. + * @returns {Promise} The cloned, modified output tensor. + */ + async morphologicalOperation(operation, data, kernelSize, shape = 'RECT', anchor = { x: -1, y: -1 }) { + kernelSize = validateKernel(kernelSize); + // We don't need to perform the operation if the kernel is empty. 
+ if (kernelSize.width * kernelSize.height === 1) { + return; + } + + anchor = normalizeAnchor(anchor, kernelSize); + let kernel = getStructuringElement(shape, kernelSize, anchor); + + const [batches, rows, cols] = this.dims; + const paddingSize = { width: Math.floor(kernelSize.width / 2), height: Math.floor(kernelSize.height / 2) }; + const outputData = new Float32Array(this.data.length); + const operationFunction = (operationType => { + switch (operationType) { + case 'DILATE': + return Math.max; + case 'ERODE': + return Math.min; + default: + throw new Error(`Unknown operation: ${operationType}`); + } + })(operation); + + const processChunk = async chunk => { + for (const { batchIndex, rowIndex, colIndex } of chunk) { + const kernelValues = []; + + // Collect values in the kernel window. + for (let kernelRowOffset = -paddingSize.height; kernelRowOffset <= paddingSize.height; kernelRowOffset++) { + for (let kernelColOffset = -paddingSize.width; kernelColOffset <= paddingSize.width; kernelColOffset++) { + const neighborRowIndex = rowIndex + kernelRowOffset; + const neighborColIndex = colIndex + kernelColOffset; + if (neighborRowIndex >= 0 && neighborRowIndex < rows && neighborColIndex >= 0 && neighborColIndex < cols) { + const neighborIndex = (batchIndex * rows * cols) + neighborRowIndex * cols + neighborColIndex; + // Only include values where the kernel has a value + // of 1. + // Rather than multiply against this value, we use + // the if check to reduce the size of the array. + const kernelValue = kernel[kernelRowOffset + paddingSize.height][kernelColOffset + paddingSize.width]; + if (kernelValue === 1) { + kernelValues.push(data[neighborIndex] * kernelValue); + } + } + } + } + + // Apply operation function to the values. + const outputIndex = batchIndex * rows * cols + rowIndex * cols + colIndex; + outputData[outputIndex] = operationFunction(...kernelValues); + } + }; + + // Divide work into chunks for parallel processing. 
+ const chunks = []; + const chunkSize = Math.ceil((batches * rows * cols) / (navigator.hardwareConcurrency || 4)); + let currentChunk = []; + + for (let rowIndex = 0; rowIndex < rows; rowIndex++) { + for (let colIndex = 0; colIndex < cols; colIndex++) { + for (let batchIndex = 0; batchIndex < batches; batchIndex++) { + currentChunk.push({ batchIndex, rowIndex, colIndex }); + // Store the chunk now that it is the right size. + if (currentChunk.length >= chunkSize) { + chunks.push([...currentChunk]); + currentChunk = []; + } + } + } + } + // Get any elements that may not fit neatly in the defined chunk size. + if (currentChunk.length > 0) { + chunks.push(currentChunk); + } + + // Process all chunks in parallel. + await Promise.all(chunks.map(chunk => processChunk(chunk))); + + return outputData; + } + + /** + * Performs a morphological operation on the input image. + * + * @param {'ERODE' | 'DILATE' | 'OPEN' | 'CLOSE'} operation + * @param {KernelSize} kernelSize The width and height of the kernel. + * @param {Shape} [shape='RECT'] The shape of the kernel. + * @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. + * @returns {Promise} The cloned, modified output tensor. + */ + async morph(operation, kernelSize, shape = 'RECT', anchor = { x: -1, y: -1 }) { + switch (operation) { + case 'ERODE': + return this.erode(kernelSize, shape, anchor); + + case 'DILATE': + return this.dilate(kernelSize, shape, anchor); + + case 'OPEN': + return (await this + .erode_(kernelSize, shape, anchor)) + .dilate_(kernelSize, shape, anchor); + + case 'CLOSE': + return (await this + .dilate_(kernelSize, shape, anchor)) + .erode_(kernelSize, shape, anchor); + + default: + throw new Error("Unknown morphological operation"); + } + } + /** * Performs Tensor dtype conversion. * @param {DataType} type The desired data type. 
// Hunk @@ -1532,3 +1720,141 @@: helpers appended after `export function quantize_embeddings(tensor, precision)`.
// Unchanged diff context (tail of that function, whose body is outside this hunk):
//     return new Tensor(dtype, outputData, [tensor.dims[0], tensor.dims[1] / 8]);
// }

/**
 * Ensure that an anchor lies within the kernel size.
 * A coordinate of `-1` is replaced by the kernel centre on that axis.
 *
 * The input object is not modified; a new Point is returned.
 *
 * @param {Point} anchor The input anchor point.
 * @param {Size} kernelSize The width and height of the kernel.
 * @returns {Point} The normalized anchor point.
 * @throws {Error} If a coordinate is not an integer or lies outside the kernel.
 */
function normalizeAnchor(anchor, kernelSize) {
    const x = anchor.x === -1 ? Math.floor(kernelSize.width / 2) : anchor.x;
    const y = anchor.y === -1 ? Math.floor(kernelSize.height / 2) : anchor.y;

    // NaN and fractional values pass every range comparison below, so reject them first.
    if (!Number.isInteger(x) || !Number.isInteger(y)) {
        throw new Error("Anchor coordinates must be integers.");
    }
    if (x < 0 || x >= kernelSize.width || y < 0 || y >= kernelSize.height) {
        throw new Error("Anchor is out of bounds for the given kernel size.");
    }
    return { x, y };
}

/**
 * Creates a Size object that represents a kernel.
 * Accepts a single integer (square kernel), a `[width, height]` pair, or a
 * `{width, height}` object, and always returns a fresh Size object.
 *
 * @param {KernelSize} kernelSize The size of the kernel.
 * @returns {Size} An object representing the kernel width and height.
 * @throws {Error} If the kernel size is malformed, non-integer or not positive.
 * @throws {Error} If kernel size is even.
 */
function validateKernel(kernelSize) {
    let width;
    let height;

    if (typeof kernelSize === 'number') {
        // A single number is used as both the width and the height.
        width = kernelSize;
        height = kernelSize;
    } else if (Array.isArray(kernelSize) && kernelSize.length === 2) {
        // An array of two values is read as width then height.
        [width, height] = kernelSize;
    } else if (
        kernelSize !== null && typeof kernelSize === 'object' &&
        'width' in kernelSize && 'height' in kernelSize
    ) {
        // `typeof null` is 'object', hence the explicit null guard above.
        ({ width, height } = kernelSize);
    } else {
        throw new Error("Invalid kernel size.");
    }

    if (!Number.isInteger(width) || !Number.isInteger(height)) {
        throw new Error("Invalid kernel size.");
    }
    if (width % 2 === 0 || height % 2 === 0) {
        throw new Error("Kernel size must be odd");
    }
    // Negative odd values survive the parity check (-3 % 2 === -1).
    if (width < 1 || height < 1) {
        throw new Error("Invalid kernel size.");
    }

    return { width, height };
}

/**
 * Creates a structuring element for morphological operations.
 *
 * This function is a JavaScript translation of the [OpenCV C++ function of the same name](https://github.com/egonSchiele/OpenCV/blob/master/modules/imgproc/src/morph.cpp#L981).
 *
 * @param {Shape} shape The shape of the kernel.
 * @param {Size} kernelSize The width and height of the kernel.
 * @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel.
 * @returns {Array<Array<number>>} The structuring element as a 2D array of
 * 0s and 1s, indexed as `[row][column]`.
 * @throws {Error} If the shape, kernel size or anchor is invalid.
 */
function getStructuringElement(shape, kernelSize, anchor = { x: -1, y: -1 }) {
    if (!['RECT', 'CROSS', 'ELLIPSE'].includes(shape)) {
        throw new Error("Invalid shape. Must be 'RECT', 'CROSS', or 'ELLIPSE'.");
    }

    const kernel = validateKernel(kernelSize);
    const center = normalizeAnchor(anchor, kernel);

    // Every shape degenerates to a single filled cell for a 1x1 kernel.
    const effectiveShape = kernel.width === 1 && kernel.height === 1 ? 'RECT' : shape;

    let rowRadius = 0; // Radius along the height.
    let colRadius = 0; // Radius along the width.
    let inverseRowRadiusSquared = 0; // 1 / rowRadius^2, or 0 for a single-row kernel.

    if (effectiveShape === 'ELLIPSE') {
        rowRadius = Math.floor(kernel.height / 2);
        colRadius = Math.floor(kernel.width / 2);
        inverseRowRadiusSquared = rowRadius > 0 ? 1 / (rowRadius * rowRadius) : 0;
    }

    const kernelArray = Array.from({ length: kernel.height }, () => Array(kernel.width).fill(0));

    for (let row = 0; row < kernel.height; row++) {
        // Half-open column range [startColumn, endColumn) that is set to 1 in this row.
        let startColumn = 0;
        let endColumn = 0;

        if (effectiveShape === 'RECT' || (effectiveShape === 'CROSS' && row === center.y)) {
            // Whole row: every row of a rectangle, or the horizontal bar of a cross.
            endColumn = kernel.width;
        } else if (effectiveShape === 'CROSS') {
            // Vertical bar of the cross: only the anchor column.
            startColumn = center.x;
            endColumn = startColumn + 1;
        } else {
            // ELLIPSE: solve x^2/a^2 + y^2/b^2 = 1 for the half-width of this row.
            const verticalOffset = row - center.y;

            if (Math.abs(verticalOffset) <= rowRadius) {
                // OpenCV converts this value with saturate_cast<int>, which
                // rounds to nearest; flooring yields a narrower ellipse.
                const horizontalRadius = Math.round(
                    colRadius * Math.sqrt(
                        Math.max(0, (rowRadius * rowRadius) - (verticalOffset * verticalOffset)) * inverseRowRadiusSquared
                    )
                );

                startColumn = Math.max(center.x - horizontalRadius, 0);
                endColumn = Math.min(center.x + horizontalRadius + 1, kernel.width);
            }
        }

        for (let col = startColumn; col < endColumn; col++) {
            kernelArray[row][col] = 1;
        }
    }

    return kernelArray;
}