Skip to content

Commit 45c98ec

Browse files
Add ops.image.scale_and_translate. (#21577)
* Add `ops.image.scale_and_translate`. * Fix op tests. * Fix np resize.
1 parent 0387d30 commit 45c98ec

File tree

9 files changed

+712
-109
lines changed

9 files changed

+712
-109
lines changed

keras/api/_tf_keras/keras/ops/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from keras.src.ops.image import resize as resize
1717
from keras.src.ops.image import rgb_to_grayscale as rgb_to_grayscale
1818
from keras.src.ops.image import rgb_to_hsv as rgb_to_hsv
19+
from keras.src.ops.image import scale_and_translate as scale_and_translate

keras/api/ops/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from keras.src.ops.image import resize as resize
1717
from keras.src.ops.image import rgb_to_grayscale as rgb_to_grayscale
1818
from keras.src.ops.image import rgb_to_hsv as rgb_to_hsv
19+
from keras.src.ops.image import scale_and_translate as scale_and_translate

keras/src/backend/jax/image.py

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,34 @@
1414
"lanczos5",
1515
"bicubic",
1616
)
17+
AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
18+
"nearest": 0,
19+
"bilinear": 1,
20+
}
21+
AFFINE_TRANSFORM_FILL_MODES = {
22+
"constant",
23+
"nearest",
24+
"wrap",
25+
"mirror",
26+
"reflect",
27+
}
28+
MAP_COORDINATES_FILL_MODES = {
29+
"constant",
30+
"nearest",
31+
"wrap",
32+
"mirror",
33+
"reflect",
34+
}
35+
SCALE_AND_TRANSLATE_METHODS = {
36+
"linear",
37+
"bilinear",
38+
"trilinear",
39+
"cubic",
40+
"bicubic",
41+
"tricubic",
42+
"lanczos3",
43+
"lanczos5",
44+
}
1745

1846

1947
def rgb_to_grayscale(images, data_format=None):
@@ -372,19 +400,6 @@ def resize(
372400
)
373401

374402

375-
AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
376-
"nearest": 0,
377-
"bilinear": 1,
378-
}
379-
AFFINE_TRANSFORM_FILL_MODES = {
380-
"constant",
381-
"nearest",
382-
"wrap",
383-
"mirror",
384-
"reflect",
385-
}
386-
387-
388403
def affine_transform(
389404
images,
390405
transform,
@@ -483,15 +498,6 @@ def affine_transform(
483498
return affined
484499

485500

486-
MAP_COORDINATES_FILL_MODES = {
487-
"constant",
488-
"nearest",
489-
"wrap",
490-
"mirror",
491-
"reflect",
492-
}
493-
494-
495501
def perspective_transform(
496502
images,
497503
start_points,
@@ -545,7 +551,7 @@ def perspective_transform(
545551
if data_format == "channels_first":
546552
images = jnp.transpose(images, (0, 2, 3, 1))
547553

548-
batch_size, height, width, channels = images.shape
554+
_, height, width, _ = images.shape
549555
transforms = compute_homography_matrix(
550556
jnp.asarray(start_points, dtype="float32"),
551557
jnp.asarray(end_points, dtype="float32"),
@@ -859,3 +865,31 @@ def elastic_transform(
859865
transformed_images = transformed_images.astype(input_dtype)
860866

861867
return transformed_images
868+
869+
870+
def scale_and_translate(
871+
images,
872+
output_shape,
873+
scale,
874+
translation,
875+
spatial_dims,
876+
method,
877+
antialias=True,
878+
):
879+
if method not in SCALE_AND_TRANSLATE_METHODS:
880+
raise ValueError(
881+
"Invalid value for argument `method`. Expected of one "
882+
f"{SCALE_AND_TRANSLATE_METHODS}. Received: method={method}"
883+
)
884+
images = convert_to_tensor(images)
885+
scale = convert_to_tensor(scale)
886+
translation = convert_to_tensor(translation)
887+
return jax.image.scale_and_translate(
888+
images,
889+
output_shape,
890+
spatial_dims,
891+
scale,
892+
translation,
893+
method,
894+
antialias,
895+
)

keras/src/backend/numpy/image.py

Lines changed: 98 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@
1313
"lanczos5",
1414
"bicubic",
1515
)
16+
AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
17+
"nearest": 0,
18+
"bilinear": 1,
19+
}
20+
AFFINE_TRANSFORM_FILL_MODES = {
21+
"constant",
22+
"nearest",
23+
"wrap",
24+
"mirror",
25+
"reflect",
26+
}
27+
MAP_COORDINATES_FILL_MODES = {
28+
"constant",
29+
"nearest",
30+
"wrap",
31+
"mirror",
32+
"reflect",
33+
}
34+
SCALE_AND_TRANSLATE_METHODS = {
35+
"linear",
36+
"bilinear",
37+
"trilinear",
38+
"cubic",
39+
"bicubic",
40+
"tricubic",
41+
"lanczos3",
42+
"lanczos5",
43+
}
1644

1745

1846
def rgb_to_grayscale(images, data_format=None):
@@ -367,7 +395,7 @@ def resize(
367395
return _resize(images, size, method=interpolation, antialias=antialias)
368396

369397

370-
def compute_weight_mat(
398+
def _compute_weight_mat(
371399
input_size, output_size, scale, translation, kernel, antialias
372400
):
373401
dtype = np.result_type(scale, translation)
@@ -410,32 +438,11 @@ def compute_weight_mat(
410438

411439

412440
def _resize(image, shape, method, antialias):
413-
def _fill_triangle_kernel(x):
414-
return np.maximum(0, 1 - np.abs(x))
415-
416-
def _fill_keys_cubic_kernel(x):
417-
out = ((1.5 * x - 2.5) * x) * x + 1.0
418-
out = np.where(x >= 1.0, ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0, out)
419-
return np.where(x >= 2.0, 0.0, out)
420-
421-
def _fill_lanczos_kernel(radius, x):
422-
y = radius * np.sin(np.pi * x) * np.sin(np.pi * x / radius)
423-
out = np.where(
424-
x > 1e-3, np.divide(y, np.where(x != 0, np.pi**2 * x**2, 1)), 1
425-
)
426-
return np.where(x > radius, 0.0, out)
427-
428441
if method == "nearest":
429442
return _resize_nearest(image, shape)
430-
elif method == "bilinear":
431-
kernel = _fill_triangle_kernel
432-
elif method == "lanczos3":
433-
kernel = lambda x: _fill_lanczos_kernel(3.0, x)
434-
elif method == "lanczos5":
435-
kernel = lambda x: _fill_lanczos_kernel(5.0, x)
436-
elif method == "bicubic":
437-
kernel = _fill_keys_cubic_kernel
438443
else:
444+
kernel = _kernels.get(method, None)
445+
if kernel is None:
439446
raise ValueError("Unknown resize method")
440447

441448
spatial_dims = tuple(
@@ -473,6 +480,34 @@ def _resize_nearest(x, output_shape):
473480
return x
474481

475482

483+
def _fill_triangle_kernel(x):
484+
return np.maximum(0, 1 - np.abs(x))
485+
486+
487+
def _fill_keys_cubic_kernel(x):
488+
out = ((1.5 * x - 2.5) * x) * x + 1.0
489+
out = np.where(x >= 1.0, ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0, out)
490+
return np.where(x >= 2.0, 0.0, out)
491+
492+
493+
def _fill_lanczos_kernel(radius, x):
494+
y = radius * np.sin(np.pi * x) * np.sin(np.pi * x / radius)
495+
out = np.where(
496+
x > 1e-3, np.divide(y, np.where(x != 0, np.pi**2 * x**2, 1)), 1
497+
)
498+
return np.where(x > radius, 0.0, out)
499+
500+
501+
_kernels = {
502+
"linear": _fill_triangle_kernel,
503+
"bilinear": _fill_triangle_kernel, # For `resize`.
504+
"cubic": _fill_keys_cubic_kernel,
505+
"bicubic": _fill_keys_cubic_kernel, # For `resize`.
506+
"lanczos3": lambda x: _fill_lanczos_kernel(3.0, x),
507+
"lanczos5": lambda x: _fill_lanczos_kernel(5.0, x),
508+
}
509+
510+
476511
def _scale_and_translate(
477512
x, output_shape, spatial_dims, scale, translation, kernel, antialias
478513
):
@@ -492,9 +527,9 @@ def _scale_and_translate(
492527
d = d % x.ndim
493528
m, n = input_shape[d], output_shape[d]
494529

495-
w = compute_weight_mat(
530+
w = _compute_weight_mat(
496531
m, n, scale[i], translation[i], kernel, antialias
497-
).astype(np.float32)
532+
).astype(output.dtype)
498533
output = np.tensordot(output, w, axes=(d, 0))
499534
output = np.moveaxis(output, -1, d)
500535

@@ -504,19 +539,6 @@ def _scale_and_translate(
504539
return output
505540

506541

507-
AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
508-
"nearest": 0,
509-
"bilinear": 1,
510-
}
511-
AFFINE_TRANSFORM_FILL_MODES = {
512-
"constant",
513-
"nearest",
514-
"wrap",
515-
"mirror",
516-
"reflect",
517-
}
518-
519-
520542
def affine_transform(
521543
images,
522544
transform,
@@ -877,15 +899,6 @@ def compute_homography_matrix(start_points, end_points):
877899
return homography_matrix
878900

879901

880-
MAP_COORDINATES_FILL_MODES = {
881-
"constant",
882-
"nearest",
883-
"wrap",
884-
"mirror",
885-
"reflect",
886-
}
887-
888-
889902
def map_coordinates(
890903
inputs, coordinates, order, fill_mode="constant", fill_value=0.0
891904
):
@@ -1135,3 +1148,40 @@ def elastic_transform(
11351148
transformed_images = transformed_images.astype(input_dtype)
11361149

11371150
return transformed_images
1151+
1152+
1153+
def scale_and_translate(
1154+
images,
1155+
output_shape,
1156+
scale,
1157+
translation,
1158+
spatial_dims,
1159+
method,
1160+
antialias=True,
1161+
):
1162+
if method not in SCALE_AND_TRANSLATE_METHODS:
1163+
raise ValueError(
1164+
"Invalid value for argument `method`. Expected of one "
1165+
f"{SCALE_AND_TRANSLATE_METHODS}. Received: method={method}"
1166+
)
1167+
if method in ("linear", "bilinear", "trilinear", "triangle"):
1168+
method = "linear"
1169+
elif method in ("cubic", "bicubic", "tricubic"):
1170+
method = "cubic"
1171+
1172+
images = convert_to_tensor(images)
1173+
scale = convert_to_tensor(scale)
1174+
translation = convert_to_tensor(translation)
1175+
kernel = _kernels[method]
1176+
dtype = backend.result_type(scale.dtype, translation.dtype)
1177+
scale = scale.astype(dtype)
1178+
translation = translation.astype(dtype)
1179+
return _scale_and_translate(
1180+
images,
1181+
output_shape,
1182+
spatial_dims,
1183+
scale,
1184+
translation,
1185+
kernel,
1186+
antialias,
1187+
)

keras/src/backend/openvino/image.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,59 @@ def affine_transform(
3131
)
3232

3333

34+
def perspective_transform(
35+
images,
36+
start_points,
37+
end_points,
38+
interpolation="bilinear",
39+
fill_value=0,
40+
data_format=None,
41+
):
42+
raise NotImplementedError(
43+
"`perspective_transform` is not supported with openvino backend"
44+
)
45+
46+
3447
def map_coordinates(
3548
inputs, coordinates, order, fill_mode="constant", fill_value=0
3649
):
3750
raise NotImplementedError(
3851
"`map_coordinates` is not supported with openvino backend"
3952
)
53+
54+
55+
def gaussian_blur(
56+
images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None
57+
):
58+
raise NotImplementedError(
59+
"`gaussian_blur` is not supported with openvino backend"
60+
)
61+
62+
63+
def elastic_transform(
64+
images,
65+
alpha=20.0,
66+
sigma=5.0,
67+
interpolation="bilinear",
68+
fill_mode="reflect",
69+
fill_value=0.0,
70+
seed=None,
71+
data_format=None,
72+
):
73+
raise NotImplementedError(
74+
"`elastic_transform` is not supported with openvino backend"
75+
)
76+
77+
78+
def scale_and_translate(
79+
images,
80+
output_shape,
81+
scale,
82+
translation,
83+
spatial_dims,
84+
method,
85+
antialias=True,
86+
):
87+
raise NotImplementedError(
88+
"`scale_and_translate` is not supported with openvino backend"
89+
)

0 commit comments

Comments
 (0)