Skip to content

Commit 1fc411f

Browse files
Add dtype tests for ops.image.*. (#21612)
1 parent a11ef39 commit 1fc411f

File tree

5 files changed

+240
-29
lines changed

5 files changed

+240
-29
lines changed

keras/src/backend/jax/image.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,9 @@ def gaussian_blur(
682682
):
683683
def _create_gaussian_kernel(kernel_size, sigma, dtype):
684684
def _get_gaussian_kernel1d(size, sigma):
685-
x = jnp.arange(size, dtype=dtype) - (size - 1) / 2
685+
x = jnp.arange(size, dtype=dtype) - jnp.array(
686+
(size - 1) / 2, dtype=dtype
687+
)
686688
kernel1d = jnp.exp(-0.5 * (x / sigma) ** 2)
687689
return kernel1d / jnp.sum(kernel1d)
688690

@@ -697,8 +699,8 @@ def _get_gaussian_kernel2d(size, sigma):
697699
return kernel
698700

699701
images = convert_to_tensor(images)
700-
sigma = convert_to_tensor(sigma)
701-
dtype = images.dtype
702+
dtype = backend.standardize_dtype(images.dtype)
703+
sigma = convert_to_tensor(sigma, dtype=dtype)
702704

703705
if len(images.shape) not in (3, 4):
704706
raise ValueError(

keras/src/backend/numpy/image.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ def affine_transform(
560560
f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
561561
)
562562

563+
images = convert_to_tensor(images)
563564
transform = convert_to_tensor(transform)
564565

565566
if len(images.shape) not in (3, 4):
@@ -575,10 +576,11 @@ def affine_transform(
575576
f"transform.shape={transform.shape}"
576577
)
577578

578-
# scipy.ndimage.map_coordinates lacks support for half precision.
579-
input_dtype = images.dtype
580-
if input_dtype == "float16":
581-
images = images.astype("float32")
579+
# `scipy.ndimage.map_coordinates` lacks support for float16 and bfloat16.
580+
input_dtype = backend.standardize_dtype(images.dtype)
581+
compute_dtype = backend.result_type(input_dtype, "float32")
582+
images = images.astype(compute_dtype)
583+
transform = transform.astype(compute_dtype)
582584

583585
# unbatched case
584586
need_squeeze = False
@@ -622,7 +624,7 @@ def affine_transform(
622624
# transform the indices
623625
coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
624626
coordinates = np.moveaxis(coordinates, source=-1, destination=1)
625-
coordinates += np.reshape(offset, newshape=(*offset.shape, 1, 1, 1))
627+
coordinates += np.reshape(offset, (*offset.shape, 1, 1, 1))
626628

627629
# apply affine transformation
628630
affined = np.stack(
@@ -643,9 +645,7 @@ def affine_transform(
643645
affined = np.transpose(affined, (0, 3, 1, 2))
644646
if need_squeeze:
645647
affined = np.squeeze(affined, axis=0)
646-
if input_dtype == "float16":
647-
affined = affined.astype(input_dtype)
648-
return affined
648+
return affined.astype(input_dtype)
649649

650650

651651
def perspective_transform(
@@ -758,6 +758,14 @@ def perspective_transform(
758758

759759

760760
def compute_homography_matrix(start_points, end_points):
761+
start_points = convert_to_tensor(start_points)
762+
end_points = convert_to_tensor(end_points)
763+
dtype = backend.result_type(start_points.dtype, end_points.dtype, float)
764+
# `np.linalg.solve` lacks support for float16 and bfloat16.
765+
compute_dtype = backend.result_type(dtype, "float32")
766+
start_points = start_points.astype(dtype)
767+
end_points = end_points.astype(dtype)
768+
761769
start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1]
762770
start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1]
763771
start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1]
@@ -892,11 +900,11 @@ def compute_homography_matrix(start_points, end_points):
892900
axis=-1,
893901
)
894902
target_vector = np.expand_dims(target_vector, axis=-1)
895-
903+
coefficient_matrix = coefficient_matrix.astype(compute_dtype)
904+
target_vector = target_vector.astype(compute_dtype)
896905
homography_matrix = np.linalg.solve(coefficient_matrix, target_vector)
897906
homography_matrix = np.reshape(homography_matrix, [-1, 8])
898-
899-
return homography_matrix
907+
return homography_matrix.astype(dtype)
900908

901909

902910
def map_coordinates(
@@ -950,10 +958,14 @@ def map_coordinates(
950958
)
951959
else:
952960
padded = np.pad(inputs, padding, mode=pad_mode)
961+
962+
# `scipy.ndimage.map_coordinates` lacks support for float16 and bfloat16.
963+
if backend.is_float_dtype(padded.dtype):
964+
padded = padded.astype("float32")
953965
result = scipy.ndimage.map_coordinates(
954966
padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value
955967
)
956-
return result
968+
return result.astype(inputs.dtype)
957969

958970

959971
def gaussian_blur(
@@ -979,7 +991,11 @@ def _get_gaussian_kernel2d(size, sigma):
979991
images = convert_to_tensor(images)
980992
kernel_size = convert_to_tensor(kernel_size)
981993
sigma = convert_to_tensor(sigma)
982-
input_dtype = images.dtype
994+
input_dtype = backend.standardize_dtype(images.dtype)
995+
# `scipy.signal.convolve2d` lacks support for float16 and bfloat16.
996+
compute_dtype = backend.result_type(input_dtype, "float32")
997+
images = images.astype(compute_dtype)
998+
sigma = sigma.astype(compute_dtype)
983999

9841000
if len(images.shape) not in (3, 4):
9851001
raise ValueError(
@@ -1022,8 +1038,7 @@ def _get_gaussian_kernel2d(size, sigma):
10221038
blurred_images = np.transpose(blurred_images, (0, 3, 1, 2))
10231039
if need_squeeze:
10241040
blurred_images = np.squeeze(blurred_images, axis=0)
1025-
1026-
return blurred_images
1041+
return blurred_images.astype(input_dtype)
10271042

10281043

10291044
def elastic_transform(

keras/src/backend/tensorflow/image.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -761,9 +761,9 @@ def _get_gaussian_kernel2d(size, sigma):
761761
return kernel
762762

763763
images = convert_to_tensor(images)
764-
kernel_size = convert_to_tensor(kernel_size)
765-
sigma = convert_to_tensor(sigma)
766-
dtype = images.dtype
764+
dtype = backend.standardize_dtype(images.dtype)
765+
kernel_size = convert_to_tensor(kernel_size, dtype=dtype)
766+
sigma = convert_to_tensor(sigma, dtype=dtype)
767767

768768
if len(images.shape) not in (3, 4):
769769
raise ValueError(

keras/src/backend/torch/image.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -468,8 +468,9 @@ def perspective_transform(
468468
data_format = backend.standardize_data_format(data_format)
469469

470470
images = convert_to_tensor(images)
471-
start_points = torch.tensor(start_points, dtype=torch.float32)
472-
end_points = torch.tensor(end_points, dtype=torch.float32)
471+
dtype = backend.standardize_dtype(images.dtype)
472+
start_points = convert_to_tensor(start_points, dtype=dtype)
473+
end_points = convert_to_tensor(end_points, dtype=dtype)
473474

474475
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
475476
raise ValueError(
@@ -525,13 +526,15 @@ def perspective_transform(
525526
transforms = transforms.repeat(batch_size, 1)
526527

527528
grid_x, grid_y = torch.meshgrid(
528-
torch.arange(width, dtype=torch.float32, device=images.device),
529-
torch.arange(height, dtype=torch.float32, device=images.device),
529+
torch.arange(width, dtype=to_torch_dtype(dtype), device=images.device),
530+
torch.arange(height, dtype=to_torch_dtype(dtype), device=images.device),
530531
indexing="xy",
531532
)
532533

533534
output = torch.empty(
534-
[batch_size, height, width, channels], device=images.device
535+
[batch_size, height, width, channels],
536+
dtype=to_torch_dtype(dtype),
537+
device=images.device,
535538
)
536539

537540
for i in range(batch_size):
@@ -563,8 +566,13 @@ def perspective_transform(
563566

564567

565568
def compute_homography_matrix(start_points, end_points):
566-
start_points = convert_to_tensor(start_points, dtype=torch.float32)
567-
end_points = convert_to_tensor(end_points, dtype=torch.float32)
569+
start_points = convert_to_tensor(start_points)
570+
end_points = convert_to_tensor(end_points)
571+
dtype = backend.result_type(start_points.dtype, end_points.dtype, float)
572+
# `torch.linalg.solve` requires float32.
573+
compute_dtype = backend.result_type(dtype, "float32")
574+
start_points = cast(start_points, dtype)
575+
end_points = cast(end_points, dtype)
568576

569577
start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1]
570578
start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1]
@@ -700,9 +708,11 @@ def compute_homography_matrix(start_points, end_points):
700708
dim=-1,
701709
).unsqueeze(-1)
702710

711+
coefficient_matrix = cast(coefficient_matrix, compute_dtype)
712+
target_vector = cast(target_vector, compute_dtype)
703713
homography_matrix = torch.linalg.solve(coefficient_matrix, target_vector)
704714
homography_matrix = homography_matrix.reshape(-1, 8)
705-
715+
homography_matrix = cast(homography_matrix, dtype)
706716
return homography_matrix
707717

708718

0 commit comments

Comments
 (0)