Skip to content

Commit fdfdd62

Browse files
committed
Move interpolate and transpose method logic from tensor_utils.js to math_utils.js
1 parent c6f5f16 commit fdfdd62

File tree

2 files changed

+139
-104
lines changed

2 files changed

+139
-104
lines changed

src/math_utils.js

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
2+
/**
3+
* @typedef {Int8Array | Uint8Array | Uint8ClampedArray | Int16Array | Uint16Array | Int32Array | Uint32Array | Float32Array | Float64Array} TypedArray
4+
* @typedef {BigInt64Array | BigUint64Array} BigTypedArray
5+
* @typedef {TypedArray | BigTypedArray} AnyTypedArray
6+
*/
7+
8+
/**
 * Resize an image stored in a flat typed array using bilinear interpolation.
 *
 * The input is indexed as `channel * (in_height * in_width) + row * in_width + column`,
 * i.e. channel-first, row-major. The output uses the same layout with the new
 * height and width, and is created with the same typed-array constructor as `input`.
 *
 * Each output value is a weighted sum of the (up to) four nearest input pixels.
 * Neighbours that would fall outside the image are clamped to the border pixel.
 *
 * NOTE: only bilinear sampling is implemented; `mode` is accepted for API
 * compatibility but is currently ignored.
 *
 * @param {TypedArray} input The flat input data.
 * @param {number[]} in_shape The input shape as `[in_channels, in_height, in_width]`.
 * @param {number[]} out_shape The requested output size as `[out_height, out_width]`.
 * @param {string} [mode='bilinear'] The interpolation mode (currently unused).
 * @param {boolean} [align_corners=false] If `false`, pixels are treated as unit squares and
 * their centers are mapped onto each other (half-pixel offset). If `true`, the centers of the
 * corner pixels of input and output are aligned, so the corner values are preserved exactly.
 * @returns {TypedArray} The resized data, of length `in_channels * out_height * out_width`.
 */
function interpolate(input, [in_channels, in_height, in_width], [out_height, out_width], mode = 'bilinear', align_corners = false) {
    // TODO use mode

    // Ratio between output and input sizes (used for the half-pixel mapping)
    const x_scale = out_width / in_width;
    const y_scale = out_height / in_height;

    // Distance, in input pixels, between consecutive output samples when the corners
    // are aligned. A single output sample maps onto the first input pixel.
    const x_step = out_width > 1 ? (in_width - 1) / (out_width - 1) : 0;
    const y_step = out_height > 1 ? (in_height - 1) / (out_height - 1) : 0;

    // Output image
    // @ts-ignore
    const out_img = new input.constructor(out_height * out_width * in_channels);

    // Pre-calculate strides (number of elements per channel)
    const inStride = in_height * in_width;
    const outStride = out_height * out_width;

    for (let i = 0; i < out_height; ++i) {
        // Vertical input coordinate for this output row
        const y = align_corners ? i * y_step : (i + 0.5) / y_scale - 0.5;

        // Rows above and below the sample, clamped to the image bounds
        let y1 = Math.floor(y);
        const y2 = Math.min(y1 + 1, in_height - 1);
        y1 = Math.max(y1, 0);

        // Fractional vertical distance from the upper row. When clamping makes both rows
        // identical this may fall outside [0, 1], but the weights still sum to 1.
        const t = y - y1;

        // Offsets of the two rows within a single channel
        const topRow = y1 * in_width;
        const bottomRow = y2 * in_width;

        for (let j = 0; j < out_width; ++j) {
            // Calculate output offset
            const outOffset = i * out_width + j;

            // Horizontal input coordinate for this output column
            const x = align_corners ? j * x_step : (j + 0.5) / x_scale - 0.5;

            // Columns left and right of the sample, clamped to the image bounds
            let x1 = Math.floor(x);
            const x2 = Math.min(x1 + 1, in_width - 1);
            x1 = Math.max(x1, 0);

            // Fractional horizontal distance from the left column
            const s = x - x1;

            // Bilinear weights of the four neighbours
            const w1 = (1 - s) * (1 - t);
            const w2 = s * (1 - t);
            const w3 = (1 - s) * t;
            const w4 = s * t;

            // Indices of the four neighbours within a single channel
            const idx1 = topRow + x1;
            const idx2 = topRow + x2;
            const idx3 = bottomRow + x1;
            const idx4 = bottomRow + x2;

            for (let k = 0; k < in_channels; ++k) {
                // Calculate channel offset
                const cOffset = k * inStride;

                out_img[k * outStride + outOffset] =
                    w1 * input[cOffset + idx1] +
                    w2 * input[cOffset + idx2] +
                    w3 * input[cOffset + idx3] +
                    w4 * input[cOffset + idx4];
            }
        }
    }

    return out_img;
}
79+
80+
81+
/**
 * Helper method to transpose an AnyTypedArray directly.
 *
 * The data in `array` is interpreted as a row-major tensor of shape `dims`. The result is a
 * new array of the same type and length whose dimensions are reordered so that output
 * dimension `i` corresponds to input dimension `axes[i]`.
 *
 * @template {AnyTypedArray} T
 * @param {T} array The flat data to transpose. It is not modified.
 * @param {number[]} dims The shape of the data in `array`.
 * @param {number[]} axes A permutation of `0..dims.length-1` giving the new order of the dimensions.
 * @returns {[T, number[]]} The transposed array and the new shape.
 * @throws {Error} If `axes` is not a permutation of the dimension indices of `dims`.
 */
function transpose_data(array, dims, axes) {
    const rank = dims.length;
    if (axes.length !== rank) {
        throw new Error(`axes must have the same length as dims (expected ${rank}, got ${axes.length}).`);
    }

    // Validate `axes` while building its inverse: position[a] is the output
    // dimension at which input dimension `a` ends up.
    const position = new Array(rank).fill(-1);
    for (let i = 0; i < rank; ++i) {
        const axis = axes[i];
        if (!Number.isInteger(axis) || axis < 0 || axis >= rank || position[axis] !== -1) {
            throw new Error(`axes must be a permutation of the dimension indices 0..${rank - 1}, got [${axes}].`);
        }
        position[axis] = i;
    }

    // Calculate the new shape of the transposed array
    // and the (row-major) stride of each of its dimensions
    const shape = new Array(rank);
    const stride = new Array(rank);

    for (let i = rank - 1, s = 1; i >= 0; --i) {
        stride[i] = s;
        shape[i] = dims[axes[i]];
        s *= shape[i];
    }

    // For each input dimension, the stride it has in the transposed array
    const invStride = position.map((p) => stride[p]);

    // Create the transposed array with the same type and length
    // @ts-ignore
    const transposedData = new array.constructor(array.length);

    // Transpose the original array to the new array
    for (let i = 0; i < array.length; ++i) {
        // Decompose the flat input index into per-dimension coordinates (last dimension
        // first) and accumulate the corresponding flat index in the output.
        let newIndex = 0;
        for (let j = rank - 1, k = i; j >= 0; --j) {
            newIndex += (k % dims[j]) * invStride[j];
            k = Math.floor(k / dims[j]);
        }
        transposedData[newIndex] = array[i];
    }

    return [transposedData, shape];
}
120+
121+
// Public API of this module (CommonJS export object).
module.exports = {
122+
interpolate,
123+
transpose: transpose_data, // exported under the name `transpose`
124+
}

src/tensor_utils.js

Lines changed: 15 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
const { ONNX } = require('./backends/onnx.js');
22

3+
const { interpolate: interpolate_data, transpose: transpose_data } = require('./math_utils.js');
4+
35

46
/**
5-
* @typedef {Int8Array | Uint8Array | Uint8ClampedArray | Int16Array | Uint16Array | Int32Array | Uint32Array | Float32Array | Float64Array | BigInt64Array | BigUint64Array} TypedArray
7+
* @typedef {import('./math_utils.js').AnyTypedArray} AnyTypedArray
68
*/
79

10+
const ONNXTensor = ONNX.Tensor;
11+
812
// TODO: fix error below
9-
class Tensor extends ONNX.Tensor {
13+
class Tensor extends ONNXTensor {
1014
/**
1115
* Create a new Tensor or copy an existing Tensor.
12-
* @param {[string, Array|TypedArray, number[]]|[ONNX.Tensor]} args
16+
* @param {[string, Array|AnyTypedArray, number[]]|[ONNXTensor]} args
1317
*/
1418
constructor(...args) {
1519
if (args[0] instanceof ONNX.Tensor) {
@@ -191,43 +195,6 @@ function transpose(tensor, axes) {
191195
return new Tensor(tensor.type, transposedData, shape);
192196
}
193197

194-
/**
195-
* Helper method to transpose a TypedArray directly
196-
* @param {TypedArray} array
197-
* @param {number[]} dims
198-
* @param {number[]} axes
199-
* @returns {[TypedArray, number[]]} The transposed array and the new shape.
200-
*/
201-
function transpose_data(array, dims, axes) {
202-
// Calculate the new shape of the transposed array
203-
// and the stride of the original array
204-
const shape = new Array(axes.length);
205-
const stride = new Array(axes.length);
206-
207-
for (let i = axes.length - 1, s = 1; i >= 0; --i) {
208-
stride[i] = s;
209-
shape[i] = dims[axes[i]];
210-
s *= shape[i];
211-
}
212-
213-
// Precompute inverse mapping of stride
214-
const invStride = axes.map((_, i) => stride[axes.indexOf(i)]);
215-
216-
// Create the transposed array with the new shape
217-
const transposedData = new array.constructor(array.length);
218-
219-
// Transpose the original array to the new array
220-
for (let i = 0; i < array.length; ++i) {
221-
let newIndex = 0;
222-
for (let j = dims.length - 1, k = i; j >= 0; --j) {
223-
newIndex += (k % dims[j]) * invStride[j];
224-
k = Math.floor(k / dims[j]);
225-
}
226-
transposedData[newIndex] = array[i];
227-
}
228-
229-
return [transposedData, shape];
230-
}
231198

232199
/**
233200
* Concatenates an array of tensors along the 0th dimension.
@@ -275,76 +242,20 @@ function cat(tensors) {
275242
* @returns {Tensor} - The interpolated tensor.
276243
*/
277244
function interpolate(input, [out_height, out_width], mode = 'bilinear', align_corners = false) {
278-
// TODO use mode and align_corners
279245

280246
// Input image dimensions
281247
const in_channels = input.dims.at(-3) ?? 1;
282248
const in_height = input.dims.at(-2);
283249
const in_width = input.dims.at(-1);
284250

285-
// Output image dimensions
286-
const x_scale = out_width / in_width;
287-
const y_scale = out_height / in_height;
288-
289-
// Output image
290-
const out_img = new input.data.constructor(out_height * out_width * in_channels);
291-
292-
// Pre-calculate strides
293-
const inStride = in_height * in_width;
294-
const outStride = out_height * out_width;
295-
296-
for (let i = 0; i < out_height; ++i) {
297-
for (let j = 0; j < out_width; ++j) {
298-
// Calculate output offset
299-
const outOffset = i * out_width + j;
300-
301-
// Calculate input pixel coordinates
302-
const x = (j + 0.5) / x_scale - 0.5;
303-
const y = (i + 0.5) / y_scale - 0.5;
304-
305-
// Calculate the four nearest input pixels
306-
// We also check if the input pixel coordinates are within the image bounds
307-
let x1 = Math.floor(x);
308-
let y1 = Math.floor(y);
309-
const x2 = Math.min(x1 + 1, in_width - 1);
310-
const y2 = Math.min(y1 + 1, in_height - 1);
311-
312-
x1 = Math.max(x1, 0);
313-
y1 = Math.max(y1, 0);
314-
315-
316-
// Calculate the fractional distances between the input pixel and the four nearest pixels
317-
const s = x - x1;
318-
const t = y - y1;
319-
320-
// Perform bilinear interpolation
321-
const w1 = (1 - s) * (1 - t);
322-
const w2 = s * (1 - t);
323-
const w3 = (1 - s) * t;
324-
const w4 = s * t;
325-
326-
// Calculate the four nearest input pixel indices
327-
const yStride = y1 * in_width;
328-
const xStride = y2 * in_width;
329-
const idx1 = yStride + x1;
330-
const idx2 = yStride + x2;
331-
const idx3 = xStride + x1;
332-
const idx4 = xStride + x2;
333-
334-
for (let k = 0; k < in_channels; ++k) {
335-
// Calculate channel offset
336-
const cOffset = k * inStride;
337-
338-
out_img[k * outStride + outOffset] =
339-
w1 * input.data[cOffset + idx1] +
340-
w2 * input.data[cOffset + idx2] +
341-
w3 * input.data[cOffset + idx3] +
342-
w4 * input.data[cOffset + idx4];
343-
}
344-
}
345-
}
346-
347-
return new Tensor(input.type, out_img, [in_channels, out_height, out_width]);
251+
let output = interpolate_data(
252+
input.data,
253+
[in_channels, in_height, in_width],
254+
[out_height, out_width],
255+
mode,
256+
align_corners
257+
);
258+
return new Tensor(input.type, output, [in_channels, out_height, out_width]);
348259
}
349260

350261
module.exports = {

0 commit comments

Comments
 (0)