@@ -69,33 +69,20 @@ def render_image(
6969 if not output_format :
7070 output_format = ImageType .jpeg if mask .all () else ImageType .png
7171
72- if output_format == ImageType .png and data .dtype not in ["uint8" , "uint16" ]:
72+ # format-specific valid dtypes
73+ format_dtypes = {
74+ ImageType .png : ["uint8" , "uint16" ],
75+ ImageType .jpeg : ["uint8" ],
76+ ImageType .jpg : ["uint8" ],
77+ ImageType .webp : ["uint8" ],
78+ ImageType .jp2 : ["uint8" , "int16" , "uint16" ],
79+ }
80+
81+ valid_dtypes = format_dtypes .get (output_format , [])
82+ if valid_dtypes and data .dtype not in valid_dtypes :
7383 warnings .warn (
74- f"Invalid type: `{ data .dtype } ` for the `{ output_format } ` driver. Data will be rescaled using min/max type bounds or dataset_statistics." ,
75- InvalidDatatypeWarning ,
76- stacklevel = 1 ,
77- )
78- data = rescale_array (data , mask , in_range = datatype_range )
79-
80- elif output_format in [
81- ImageType .jpeg ,
82- ImageType .jpg ,
83- ImageType .webp ,
84- ] and data .dtype not in ["uint8" ]:
85- warnings .warn (
86- f"Invalid type: `{ data .dtype } ` for the `{ output_format } ` driver. Data will be rescaled using min/max type bounds or dataset_statistics." ,
87- InvalidDatatypeWarning ,
88- stacklevel = 1 ,
89- )
90- data = rescale_array (data , mask , in_range = datatype_range )
91-
92- elif output_format == ImageType .jp2 and data .dtype not in [
93- "uint8" ,
94- "int16" ,
95- "uint16" ,
96- ]:
97- warnings .warn (
98- f"Invalid type: `{ data .dtype } ` for the `{ output_format } ` driver. Data will be rescaled using min/max type bounds or dataset_statistics." ,
84+ f"Invalid type: `{ data .dtype } ` for the `{ output_format } ` driver. "
85+ "Data will be rescaled using min/max type bounds or dataset_statistics." ,
9986 InvalidDatatypeWarning ,
10087 stacklevel = 1 ,
10188 )
0 commit comments