|
1 | 1 | const { ONNX } = require('./backends/onnx.js'); |
2 | 2 |
|
| 3 | +const { interpolate: interpolate_data, transpose: transpose_data } = require('./math_utils.js'); |
| 4 | + |
3 | 5 |
|
4 | 6 | /** |
5 | | - * @typedef {Int8Array | Uint8Array | Uint8ClampedArray | Int16Array | Uint16Array | Int32Array | Uint32Array | Float32Array | Float64Array | BigInt64Array | BigUint64Array} TypedArray |
| 7 | + * @typedef {import('./math_utils.js').AnyTypedArray} AnyTypedArray |
6 | 8 | */ |
7 | 9 |
|
| 10 | +const ONNXTensor = ONNX.Tensor; |
| 11 | + |
8 | 12 | // TODO: fix error below |
9 | | -class Tensor extends ONNX.Tensor { |
| 13 | +class Tensor extends ONNXTensor { |
10 | 14 | /** |
11 | 15 | * Create a new Tensor or copy an existing Tensor. |
12 | | - * @param {[string, Array|TypedArray, number[]]|[ONNX.Tensor]} args |
| 16 | + * @param {[string, Array|AnyTypedArray, number[]]|[ONNXTensor]} args |
13 | 17 | */ |
14 | 18 | constructor(...args) { |
15 | 19 | if (args[0] instanceof ONNX.Tensor) { |
@@ -191,43 +195,6 @@ function transpose(tensor, axes) { |
191 | 195 | return new Tensor(tensor.type, transposedData, shape); |
192 | 196 | } |
193 | 197 |
|
194 | | -/** |
195 | | - * Helper method to transpose a TypedArray directly |
196 | | - * @param {TypedArray} array |
197 | | - * @param {number[]} dims |
198 | | - * @param {number[]} axes |
199 | | - * @returns {[TypedArray, number[]]} The transposed array and the new shape. |
200 | | - */ |
201 | | -function transpose_data(array, dims, axes) { |
202 | | - // Calculate the new shape of the transposed array |
203 | | - // and the stride of the original array |
204 | | - const shape = new Array(axes.length); |
205 | | - const stride = new Array(axes.length); |
206 | | - |
207 | | - for (let i = axes.length - 1, s = 1; i >= 0; --i) { |
208 | | - stride[i] = s; |
209 | | - shape[i] = dims[axes[i]]; |
210 | | - s *= shape[i]; |
211 | | - } |
212 | | - |
213 | | - // Precompute inverse mapping of stride |
214 | | - const invStride = axes.map((_, i) => stride[axes.indexOf(i)]); |
215 | | - |
216 | | - // Create the transposed array with the new shape |
217 | | - const transposedData = new array.constructor(array.length); |
218 | | - |
219 | | - // Transpose the original array to the new array |
220 | | - for (let i = 0; i < array.length; ++i) { |
221 | | - let newIndex = 0; |
222 | | - for (let j = dims.length - 1, k = i; j >= 0; --j) { |
223 | | - newIndex += (k % dims[j]) * invStride[j]; |
224 | | - k = Math.floor(k / dims[j]); |
225 | | - } |
226 | | - transposedData[newIndex] = array[i]; |
227 | | - } |
228 | | - |
229 | | - return [transposedData, shape]; |
230 | | -} |
231 | 198 |
|
232 | 199 | /** |
233 | 200 | * Concatenates an array of tensors along the 0th dimension. |
@@ -275,76 +242,20 @@ function cat(tensors) { |
275 | 242 | * @returns {Tensor} - The interpolated tensor. |
276 | 243 | */ |
function interpolate(input, [out_height, out_width], mode = 'bilinear', align_corners = false) {
    // Read the trailing (channels, height, width) dimensions of the input.
    // When there is no third-from-last dimension (e.g. a 2D input),
    // the input is treated as having a single channel.
    const in_channels = input.dims.at(-3) ?? 1;
    const in_height = input.dims.at(-2);
    const in_width = input.dims.at(-1);

    // The resampling itself is delegated to the `interpolate` helper imported
    // from './math_utils.js' (bound here as `interpolate_data`); `mode` and
    // `align_corners` are forwarded to it unchanged.
    const output = interpolate_data(
        input.data,
        [in_channels, in_height, in_width],
        [out_height, out_width],
        mode,
        align_corners
    );

    // NOTE(review): the returned tensor is always 3D
    // ([channels, out_height, out_width]); any dimensions of `input` before
    // the last three are not carried into the output shape. Confirm callers
    // only pass inputs for which this is intended.
    return new Tensor(input.type, output, [in_channels, out_height, out_width]);
}
349 | 260 |
|
350 | 261 | module.exports = { |
|
0 commit comments