Skip to content

Commit 8860359

Browse files
Add option to resize one dimension and maintain aspect ratio.
1 parent 7a0f77c commit 8860359

File tree

3 files changed

+75
-22
lines changed

3 files changed

+75
-22
lines changed

src/utils/core.js

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11

22
/**
33
* @file Core utility functions/classes for Transformers.js.
4-
*
4+
*
55
* These are only used internally, meaning an end-user shouldn't
66
* need to access anything here.
7-
*
7+
*
88
* @module utils/core
99
*/
1010

@@ -46,7 +46,7 @@ export function escapeRegExp(string) {
4646
* Check if a value is a typed array.
4747
* @param {*} val The value to check.
4848
* @returns {boolean} True if the value is a `TypedArray`, false otherwise.
49-
*
49+
*
5050
* Adapted from https://stackoverflow.com/a/71091338/13989043
5151
*/
5252
export function isTypedArray(val) {
@@ -63,6 +63,15 @@ export function isIntegralNumber(x) {
6363
return Number.isInteger(x) || typeof x === 'bigint'
6464
}
6565

66+
/**
67+
* Determine if a provided width or height is nullish.
68+
* @param {*} x The value to check.
69+
* @returns {boolean} True if the value is `null`, `undefined`, the number `-1` or the string `'-1'`, false otherwise.
70+
*/
71+
export function isNullishDimension(x) {
72+
return x === null || x === undefined || x === -1 || x === '-1';
73+
}
74+
6675
/**
6776
* Calculates the dimensions of a nested array.
6877
*
@@ -132,9 +141,9 @@ export function calculateReflectOffset(i, w) {
132141
}
133142

134143
/**
135-
*
136-
* @param {Object} o
137-
* @param {string[]} props
144+
*
145+
* @param {Object} o
146+
* @param {string[]} props
138147
* @returns {Object}
139148
*/
140149
export function pick(o, props) {
@@ -151,7 +160,7 @@ export function pick(o, props) {
151160
/**
152161
* Calculate the length of a string, taking multi-byte characters into account.
153162
* This mimics the behavior of Python's `len` function.
154-
* @param {string} s The string to calculate the length of.
163+
* @param {string} s The string to calculate the length of.
155164
* @returns {number} The length of the string.
156165
*/
157166
export function len(s) {

src/utils/image.js

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11

22
/**
3-
* @file Helper module for image processing.
4-
*
5-
* These functions and classes are only used internally,
3+
* @file Helper module for image processing.
4+
*
5+
* These functions and classes are only used internally,
66
* meaning an end-user shouldn't need to access anything here.
7-
*
7+
*
88
* @module utils/image
99
*/
1010

11+
import { isNullishDimension } from './core.js';
1112
import { getFile } from './hub.js';
1213
import { env } from '../env.js';
1314
import { Tensor } from './tensor.js';
@@ -91,7 +92,7 @@ export class RawImage {
9192
this.channels = channels;
9293
}
9394

94-
/**
95+
/**
9596
* Returns the size of the image (width, height).
9697
* @returns {[number, number]} The size of the image (width, height).
9798
*/
@@ -101,9 +102,9 @@ export class RawImage {
101102

102103
/**
103104
* Helper method for reading an image from a variety of input types.
104-
* @param {RawImage|string|URL} input
105+
* @param {RawImage|string|URL} input
105106
* @returns The image object.
106-
*
107+
*
107108
* **Example:** Read image from a URL.
108109
* ```javascript
109110
* let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg');
@@ -181,7 +182,7 @@ export class RawImage {
181182

182183
/**
183184
* Helper method to create a new Image from a tensor
184-
* @param {Tensor} tensor
185+
* @param {Tensor} tensor
185186
*/
186187
static fromTensor(tensor, channel_format = 'CHW') {
187188
if (tensor.dims.length !== 3) {
@@ -306,8 +307,8 @@ export class RawImage {
306307

307308
/**
308309
* Resize the image to the given dimensions. This method uses the canvas API to perform the resizing.
309-
* @param {number} width The width of the new image.
310-
* @param {number} height The height of the new image.
310+
* @param {number} width The width of the new image. `null` or `-1` will preserve the aspect ratio.
311+
* @param {number} height The height of the new image. `null` or `-1` will preserve the aspect ratio.
311312
* @param {Object} options Additional options for resizing.
312313
* @param {0|1|2|3|4|5|string} [options.resample] The resampling method to use.
313314
* @returns {Promise<RawImage>} `this` to support chaining.
@@ -319,6 +320,18 @@ export class RawImage {
319320
// Ensure resample method is a string
320321
let resampleMethod = RESAMPLING_MAPPING[resample] ?? resample;
321322

323+
// Calculate width / height to maintain aspect ratio, in the event that
324+
// the user passed a null value in.
325+
// This allows users to pass in something like `resize(320, null)` to
326+
// resize to 320 width, but maintain aspect ratio.
327+
if (isNullishDimension(width) && isNullishDimension(height)) {
328+
return this;
329+
} else if (isNullishDimension(width)) {
330+
width = (height / this.height) * this.width;
331+
} else if (isNullishDimension(height)) {
332+
height = (width / this.width) * this.height;
333+
}
334+
322335
if (BROWSER_ENV) {
323336
// TODO use `resample` in browser environment
324337

@@ -355,7 +368,7 @@ export class RawImage {
355368
case 'nearest':
356369
case 'bilinear':
357370
case 'bicubic':
358-
// Perform resizing using affine transform.
371+
// Perform resizing using affine transform.
359372
// This matches how the python Pillow library does it.
360373
img = img.affine([width / this.width, 0, 0, height / this.height], {
361374
interpolator: resampleMethod
@@ -368,7 +381,7 @@ export class RawImage {
368381
img = img.resize({
369382
width, height,
370383
fit: 'fill',
371-
kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3
384+
kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3
372385
});
373386
break;
374387

@@ -447,7 +460,7 @@ export class RawImage {
447460
// Create canvas object for this image
448461
const canvas = this.toCanvas();
449462

450-
// Create a new canvas of the desired size. This is needed since if the
463+
// Create a new canvas of the desired size. This is needed since if the
451464
// image is too small, we need to pad it with black pixels.
452465
const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d');
453466

@@ -495,7 +508,7 @@ export class RawImage {
495508
// Create canvas object for this image
496509
const canvas = this.toCanvas();
497510

498-
// Create a new canvas of the desired size. This is needed since if the
511+
// Create a new canvas of the desired size. This is needed since if the
499512
// image is too small, we need to pad it with black pixels.
500513
const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d');
501514

@@ -742,4 +755,4 @@ export class RawImage {
742755
}
743756
});
744757
}
745-
}
758+
}

tests/utils/utils.test.js

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { AutoProcessor, hamming, hanning, mel_filter_bank } from "../../src/transformers.js";
22
import { getFile } from "../../src/utils/hub.js";
3+
import { RawImage } from "../../src/utils/image.js";
34

45
import { MAX_TEST_EXECUTION_TIME } from "../init.js";
56
import { compare } from "../test_utils.js";
@@ -59,4 +60,34 @@ describe("Utilities", () => {
5960
expect(await data.text()).toBe("Hello, world!");
6061
});
6162
});
63+
64+
describe("Image utilities", () => {
65+
it("Read image from URL", async () => {
66+
const image = await RawImage.fromURL("https://picsum.photos/300/200");
67+
expect(image.width).toBe(300);
68+
expect(image.height).toBe(200);
69+
expect(image.channels).toBe(3);
70+
});
71+
72+
it("Can resize image", async () => {
73+
const image = await RawImage.fromURL("https://picsum.photos/300/200");
74+
const resized = await image.resize(150, 100);
75+
expect(resized.width).toBe(150);
76+
expect(resized.height).toBe(100);
77+
});
78+
79+
it("Can resize with aspect ratio", async () => {
80+
const image = await RawImage.fromURL("https://picsum.photos/300/200");
81+
const resized = await image.resize(150, null);
82+
expect(resized.width).toBe(150);
83+
expect(resized.height).toBe(100);
84+
});
85+
86+
it("Returns original image if width and height are null", async () => {
87+
const image = await RawImage.fromURL("https://picsum.photos/300/200");
88+
const resized = await image.resize(null, null);
89+
expect(resized.width).toBe(300);
90+
expect(resized.height).toBe(200);
91+
});
92+
});
6293
});

0 commit comments

Comments
 (0)