@@ -560,6 +560,7 @@ def affine_transform(
560
560
f"{ AFFINE_TRANSFORM_FILL_MODES } . Received: fill_mode={ fill_mode } "
561
561
)
562
562
563
+ images = convert_to_tensor (images )
563
564
transform = convert_to_tensor (transform )
564
565
565
566
if len (images .shape ) not in (3 , 4 ):
@@ -575,10 +576,11 @@ def affine_transform(
575
576
f"transform.shape={ transform .shape } "
576
577
)
577
578
578
- # scipy.ndimage.map_coordinates lacks support for half precision.
579
- input_dtype = images .dtype
580
- if input_dtype == "float16" :
581
- images = images .astype ("float32" )
579
+ # `scipy.ndimage.map_coordinates` lacks support for float16 and bfloat16.
580
+ input_dtype = backend .standardize_dtype (images .dtype )
581
+ compute_dtype = backend .result_type (input_dtype , "float32" )
582
+ images = images .astype (compute_dtype )
583
+ transform = transform .astype (compute_dtype )
582
584
583
585
# unbatched case
584
586
need_squeeze = False
@@ -622,7 +624,7 @@ def affine_transform(
622
624
# transform the indices
623
625
coordinates = np .einsum ("Bhwij, Bjk -> Bhwik" , indices , transform )
624
626
coordinates = np .moveaxis (coordinates , source = - 1 , destination = 1 )
625
- coordinates += np .reshape (offset , newshape = (* offset .shape , 1 , 1 , 1 ))
627
+ coordinates += np .reshape (offset , (* offset .shape , 1 , 1 , 1 ))
626
628
627
629
# apply affine transformation
628
630
affined = np .stack (
@@ -643,9 +645,7 @@ def affine_transform(
643
645
affined = np .transpose (affined , (0 , 3 , 1 , 2 ))
644
646
if need_squeeze :
645
647
affined = np .squeeze (affined , axis = 0 )
646
- if input_dtype == "float16" :
647
- affined = affined .astype (input_dtype )
648
- return affined
648
+ return affined .astype (input_dtype )
649
649
650
650
651
651
def perspective_transform (
@@ -758,6 +758,14 @@ def perspective_transform(
758
758
759
759
760
760
def compute_homography_matrix (start_points , end_points ):
761
+ start_points = convert_to_tensor (start_points )
762
+ end_points = convert_to_tensor (end_points )
763
+ dtype = backend .result_type (start_points .dtype , end_points .dtype , float )
764
+ # `np.linalg.solve` lacks support for float16 and bfloat16.
765
+ compute_dtype = backend .result_type (dtype , "float32" )
766
+ start_points = start_points .astype (dtype )
767
+ end_points = end_points .astype (dtype )
768
+
761
769
start_x1 , start_y1 = start_points [:, 0 , 0 ], start_points [:, 0 , 1 ]
762
770
start_x2 , start_y2 = start_points [:, 1 , 0 ], start_points [:, 1 , 1 ]
763
771
start_x3 , start_y3 = start_points [:, 2 , 0 ], start_points [:, 2 , 1 ]
@@ -892,11 +900,11 @@ def compute_homography_matrix(start_points, end_points):
892
900
axis = - 1 ,
893
901
)
894
902
target_vector = np .expand_dims (target_vector , axis = - 1 )
895
-
903
+ coefficient_matrix = coefficient_matrix .astype (compute_dtype )
904
+ target_vector = target_vector .astype (compute_dtype )
896
905
homography_matrix = np .linalg .solve (coefficient_matrix , target_vector )
897
906
homography_matrix = np .reshape (homography_matrix , [- 1 , 8 ])
898
-
899
- return homography_matrix
907
+ return homography_matrix .astype (dtype )
900
908
901
909
902
910
def map_coordinates (
@@ -950,10 +958,14 @@ def map_coordinates(
950
958
)
951
959
else :
952
960
padded = np .pad (inputs , padding , mode = pad_mode )
961
+
962
+ # `scipy.ndimage.map_coordinates` lacks support for float16 and bfloat16.
963
+ if backend .is_float_dtype (padded .dtype ):
964
+ padded = padded .astype ("float32" )
953
965
result = scipy .ndimage .map_coordinates (
954
966
padded , shifted_coords , order = order , mode = fill_mode , cval = fill_value
955
967
)
956
- return result
968
+ return result . astype ( inputs . dtype )
957
969
958
970
959
971
def gaussian_blur (
@@ -979,7 +991,11 @@ def _get_gaussian_kernel2d(size, sigma):
979
991
images = convert_to_tensor (images )
980
992
kernel_size = convert_to_tensor (kernel_size )
981
993
sigma = convert_to_tensor (sigma )
982
- input_dtype = images .dtype
994
+ input_dtype = backend .standardize_dtype (images .dtype )
995
+ # `scipy.signal.convolve2d` lacks support for float16 and bfloat16.
996
+ compute_dtype = backend .result_type (input_dtype , "float32" )
997
+ images = images .astype (compute_dtype )
998
+ sigma = sigma .astype (compute_dtype )
983
999
984
1000
if len (images .shape ) not in (3 , 4 ):
985
1001
raise ValueError (
@@ -1022,8 +1038,7 @@ def _get_gaussian_kernel2d(size, sigma):
1022
1038
blurred_images = np .transpose (blurred_images , (0 , 3 , 1 , 2 ))
1023
1039
if need_squeeze :
1024
1040
blurred_images = np .squeeze (blurred_images , axis = 0 )
1025
-
1026
- return blurred_images
1041
+ return blurred_images .astype (input_dtype )
1027
1042
1028
1043
1029
1044
def elastic_transform (
0 commit comments