@@ -69,14 +69,7 @@ class ImageFeatureExtractor extends FeatureExtractor {
6969 super ( config ) ;
7070
7171 this . image_mean = this . config . image_mean ;
72- if ( ! Array . isArray ( this . image_mean ) ) {
73- this . image_mean = new Array ( 3 ) . fill ( this . image_mean ) ;
74- }
75-
7672 this . image_std = this . config . image_std ;
77- if ( ! Array . isArray ( this . image_std ) ) {
78- this . image_std = new Array ( 3 ) . fill ( this . image_std ) ;
79- }
8073
8174 this . resample = this . config . resample ?? 2 ; // 2 => bilinear
8275 this . do_rescale = this . config . do_rescale ?? true ;
@@ -100,11 +93,18 @@ class ImageFeatureExtractor extends FeatureExtractor {
10093 */
10194 async preprocess ( image ) {
10295
96+ // First, convert image to RGB if specified in config.
97+ if ( this . do_convert_rgb ) {
98+ image = image . rgb ( ) ;
99+ }
100+
103101 const srcWidth = image . width ; // original width
104102 const srcHeight = image . height ; // original height
105103
106- // First , resize all images
104+ // Next , resize all images
107105 if ( this . do_resize ) {
106+ // TODO:
107+ // For efficiency reasons, it might be best to merge the resize and center crop operations into one.
108108
109109 // `this.size` comes in many forms, so we need to handle them all here:
110110 // 1. `this.size` is an integer, in which case we resize the image to be a square
@@ -153,9 +153,19 @@ class ImageFeatureExtractor extends FeatureExtractor {
153153 }
154154 }
155155
156- if ( this . do_convert_rgb ) {
157- // Convert image to RGB
158- image = image . rgb ( ) ;
156+ if ( this . do_center_crop ) {
157+
158+ let crop_width ;
159+ let crop_height ;
160+ if ( Number . isInteger ( this . crop_size ) ) {
161+ crop_width = this . crop_size ;
162+ crop_height = this . crop_size ;
163+ } else {
164+ crop_width = this . crop_size . width ;
165+ crop_height = this . crop_size . height ;
166+ }
167+
168+ image = await image . center_crop ( crop_width , crop_height ) ;
159169 }
160170
161171 const pixelData = Float32Array . from ( image . data ) ;
@@ -167,14 +177,29 @@ class ImageFeatureExtractor extends FeatureExtractor {
167177 }
168178
169179 if ( this . do_normalize ) {
170- for ( let i = 0 ; i < pixelData . length ; i += 3 ) {
171- for ( let j = 0 ; j < 3 ; ++ j ) {
180+ let image_mean = this . image_mean ;
181+ if ( ! Array . isArray ( this . image_mean ) ) {
182+ image_mean = new Array ( image . channels ) . fill ( image_mean ) ;
183+ }
184+
185+ let image_std = this . image_std ;
186+ if ( ! Array . isArray ( this . image_std ) ) {
187+ image_std = new Array ( image . channels ) . fill ( image_mean ) ;
188+ }
189+
190+ if ( image_mean . length !== image . channels || image_std . length !== image . channels ) {
191+ throw new Error ( `When set to arrays, the length of \`image_mean\` (${ image_mean . length } ) and \`image_std\` (${ image_std . length } ) must match the number of channels in the image (${ image . channels } ).` ) ;
192+ }
193+
194+ for ( let i = 0 ; i < pixelData . length ; i += image . channels ) {
195+ for ( let j = 0 ; j < image . channels ; ++ j ) {
172196 pixelData [ i + j ] = ( pixelData [ i + j ] - this . image_mean [ j ] ) / this . image_std [ j ] ;
173197 }
174198 }
175199 }
176200
177- let imgDims = [ image . height , image . width , 3 ] ;
201+ // convert to channel dimension format:
202+ let imgDims = [ image . height , image . width , image . channels ] ;
178203 let img = new Tensor ( 'float32' , pixelData , imgDims ) ;
179204 let transposed = transpose ( img , [ 2 , 0 , 1 ] ) ; // hwc -> chw
180205
0 commit comments