Skip to content

Commit 0b1be4e

Browse files
committed
Update image preprocessing method
1 parent 8ab68cb commit 0b1be4e

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

src/processors.js

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,7 @@ class ImageFeatureExtractor extends FeatureExtractor {
6969
super(config);
7070

7171
this.image_mean = this.config.image_mean;
72-
if (!Array.isArray(this.image_mean)) {
73-
this.image_mean = new Array(3).fill(this.image_mean);
74-
}
75-
7672
this.image_std = this.config.image_std;
77-
if (!Array.isArray(this.image_std)) {
78-
this.image_std = new Array(3).fill(this.image_std);
79-
}
8073

8174
this.resample = this.config.resample ?? 2; // 2 => bilinear
8275
this.do_rescale = this.config.do_rescale ?? true;
@@ -100,11 +93,18 @@ class ImageFeatureExtractor extends FeatureExtractor {
10093
*/
10194
async preprocess(image) {
10295

96+
// First, convert image to RGB if specified in config.
97+
if (this.do_convert_rgb) {
98+
image = image.rgb();
99+
}
100+
103101
const srcWidth = image.width; // original width
104102
const srcHeight = image.height; // original height
105103

106-
// First, resize all images
104+
// Next, resize all images
107105
if (this.do_resize) {
106+
// TODO:
107+
// For efficiency reasons, it might be best to merge the resize and center crop operations into one.
108108

109109
// `this.size` comes in many forms, so we need to handle them all here:
110110
// 1. `this.size` is an integer, in which case we resize the image to be a square
@@ -153,9 +153,19 @@ class ImageFeatureExtractor extends FeatureExtractor {
153153
}
154154
}
155155

156-
if (this.do_convert_rgb) {
157-
// Convert image to RGB
158-
image = image.rgb();
156+
if (this.do_center_crop) {
157+
158+
let crop_width;
159+
let crop_height;
160+
if (Number.isInteger(this.crop_size)) {
161+
crop_width = this.crop_size;
162+
crop_height = this.crop_size;
163+
} else {
164+
crop_width = this.crop_size.width;
165+
crop_height = this.crop_size.height;
166+
}
167+
168+
image = await image.center_crop(crop_width, crop_height);
159169
}
160170

161171
const pixelData = Float32Array.from(image.data);
@@ -167,14 +177,29 @@ class ImageFeatureExtractor extends FeatureExtractor {
167177
}
168178

169179
if (this.do_normalize) {
170-
for (let i = 0; i < pixelData.length; i += 3) {
171-
for (let j = 0; j < 3; ++j) {
180+
let image_mean = this.image_mean;
181+
if (!Array.isArray(this.image_mean)) {
182+
image_mean = new Array(image.channels).fill(image_mean);
183+
}
184+
185+
let image_std = this.image_std;
186+
if (!Array.isArray(this.image_std)) {
187+
image_std = new Array(image.channels).fill(image_mean);
188+
}
189+
190+
if (image_mean.length !== image.channels || image_std.length !== image.channels) {
191+
throw new Error(`When set to arrays, the length of \`image_mean\` (${image_mean.length}) and \`image_std\` (${image_std.length}) must match the number of channels in the image (${image.channels}).`);
192+
}
193+
194+
for (let i = 0; i < pixelData.length; i += image.channels) {
195+
for (let j = 0; j < image.channels; ++j) {
172196
pixelData[i + j] = (pixelData[i + j] - this.image_mean[j]) / this.image_std[j];
173197
}
174198
}
175199
}
176200

177-
let imgDims = [image.height, image.width, 3];
201+
// convert to channel dimension format:
202+
let imgDims = [image.height, image.width, image.channels];
178203
let img = new Tensor('float32', pixelData, imgDims);
179204
let transposed = transpose(img, [2, 0, 1]); // hwc -> chw
180205

0 commit comments

Comments
 (0)