@@ -111,14 +111,20 @@ def create_model(
111111 return model
112112
113113
114- def create_preprocessing (model_name : str , dtype : Optional [str ] = None ) -> Callable :
114+ def create_preprocessing (
115+ model_name : str ,
116+ * ,
117+ in_channels : Optional [float ] = None ,
118+ dtype : Optional [str ] = None ,
119+ ) -> Callable :
115120 """
116121 Creates a function to preprocess images for a particular model.
117122
118123 The input to the preprocessing function is assumed to be values in range [0, 255].
119124
120125 Args:
121126 model_name: Model for which to create preprocessing function.
127+ in_channels: Number of input channels to model
122128 dtype: Output dtype.
123129
124130 Returns:
@@ -130,9 +136,22 @@ def create_preprocessing(model_name: str, dtype: Optional[str] = None) -> Callab
130136 cfg = model_config (model_name )
131137 dtype = dtype or tf .keras .backend .floatx ()
132138
139+ def _adapt_vector (v , n ):
140+ """Adapts vector v to length n by repeating as necessary."""
141+ v = tf .convert_to_tensor (v , dtype = dtype )
142+ m = tf .shape (v )[0 ]
143+ nb_repeats = n // m + 1
144+ v = tf .tile (v , [nb_repeats ])
145+ v = v [:n ]
146+ return v
147+
148+ in_channels = in_channels or cfg .in_channels
149+ mean = _adapt_vector (cfg .mean , in_channels )
150+ std = _adapt_vector (cfg .std , in_channels )
151+
133152 def _preprocess (img : tf .Tensor ) -> tf .Tensor :
134153 img = tf .cast (img , dtype = dtype ) / 255.0
135- img = (img - cfg . mean ) / cfg . std
154+ img = (img - mean ) / std
136155 return img
137156
138157 return _preprocess
0 commit comments