Skip to content

Commit dd4f62c

Browse files
Adapted preprocessing. (#60)
1 parent 6bfa612 commit dd4f62c

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

tests/models/test_factory.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,13 @@ def test_model_path(model_name):
125125

126126

127127
@pytest.mark.parametrize("model_name", TEST_ARCHITECTURES)
128-
@pytest.mark.parametrize("input_shape", [(8, 8, 3), (1, 4, 4, 3)])
128+
@pytest.mark.parametrize("input_size", [(8, 8), (1, 4, 4)])
129+
@pytest.mark.parametrize("in_channels", [1, 3, 5, 6])
129130
@pytest.mark.parametrize("dtype", ["float32", "float16"])
130-
def test_preprocessing(model_name, input_shape, dtype):
131+
def test_preprocessing(model_name, input_size, in_channels, dtype):
132+
input_shape = (*input_size, in_channels)
131133
img = tf.ones(input_shape, dtype)
132-
preprocess = create_preprocessing(model_name, dtype)
134+
preprocess = create_preprocessing(model_name, in_channels=in_channels, dtype=dtype)
133135
img = preprocess(img)
134136
assert img.shape == input_shape
135137
assert img.dtype == dtype

tfimm/models/factory.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,20 @@ def create_model(
111111
return model
112112

113113

114-
def create_preprocessing(model_name: str, dtype: Optional[str] = None) -> Callable:
114+
def create_preprocessing(
115+
model_name: str,
116+
*,
117+
in_channels: Optional[float] = None,
118+
dtype: Optional[str] = None,
119+
) -> Callable:
115120
"""
116121
Creates a function to preprocess images for a particular model.
117122
118123
The input to the preprocessing function is assumed to be values in range [0, 255].
119124
120125
Args:
121126
model_name: Model for which to create preprocessing function.
127+
in_channels: Number of input channels to model
122128
dtype: Output dtype.
123129
124130
Returns:
@@ -130,9 +136,22 @@ def create_preprocessing(model_name: str, dtype: Optional[str] = None) -> Callab
130136
cfg = model_config(model_name)
131137
dtype = dtype or tf.keras.backend.floatx()
132138

139+
def _adapt_vector(v, n):
140+
"""Adapts vector v to length n by repeating as necessary."""
141+
v = tf.convert_to_tensor(v, dtype=dtype)
142+
m = tf.shape(v)[0]
143+
nb_repeats = n // m + 1
144+
v = tf.tile(v, [nb_repeats])
145+
v = v[:n]
146+
return v
147+
148+
in_channels = in_channels or cfg.in_channels
149+
mean = _adapt_vector(cfg.mean, in_channels)
150+
std = _adapt_vector(cfg.std, in_channels)
151+
133152
def _preprocess(img: tf.Tensor) -> tf.Tensor:
134153
img = tf.cast(img, dtype=dtype) / 255.0
135-
img = (img - cfg.mean) / cfg.std
154+
img = (img - mean) / std
136155
return img
137156

138157
return _preprocess

0 commit comments

Comments
 (0)