@@ -40,6 +40,7 @@ class ImageReadMode(Enum):
4040 GRAY_ALPHA = 2
4141 RGB = 3
4242 RGB_ALPHA = 4
43+ RGBA = RGB_ALPHA # Alias for convenience
4344
4445
4546def read_file (path : str ) -> torch .Tensor :
@@ -92,7 +93,7 @@ def decode_png(
9293 Args:
9394 input (Tensor[1]): a one dimensional uint8 tensor containing
9495 the raw bytes of the PNG image.
95- mode (ImageReadMode): the read mode used for optionally
96+ mode (str or ImageReadMode): the read mode used for optionally
9697 converting the image. Default: ``ImageReadMode.UNCHANGED``.
9798 See `ImageReadMode` class for more information on various
9899 available modes.
@@ -104,6 +105,8 @@ def decode_png(
104105 """
105106 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
106107 _log_api_usage_once (decode_png )
108+ if isinstance (mode , str ):
109+ mode = ImageReadMode [mode .upper ()]
107110 output = torch .ops .image .decode_png (input , mode .value , apply_exif_orientation )
108111 return output
109112
@@ -168,7 +171,7 @@ def decode_jpeg(
168171 input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
169172 the raw bytes of the JPEG image. The tensor(s) must be on CPU,
170173 regardless of the ``device`` parameter.
171- mode (ImageReadMode): the read mode used for optionally
174+ mode (str or ImageReadMode): the read mode used for optionally
172175 converting the image(s). The supported modes are: ``ImageReadMode.UNCHANGED``,
173176 ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``
174177 Default: ``ImageReadMode.UNCHANGED``.
@@ -198,6 +201,8 @@ def decode_jpeg(
198201 _log_api_usage_once (decode_jpeg )
199202 if isinstance (device , str ):
200203 device = torch .device (device )
204+ if isinstance (mode , str ):
205+ mode = ImageReadMode [mode .upper ()]
201206
202207 if isinstance (input , list ):
203208 if len (input ) == 0 :
@@ -298,7 +303,7 @@ def decode_image(
298303 input (Tensor or str or ``pathlib.Path``): The image to decode. If a
299304 tensor is passed, it must be one dimensional uint8 tensor containing
300305 the raw bytes of the image. Otherwise, this must be a path to the image file.
301- mode (ImageReadMode): the read mode used for optionally converting the image.
306+ mode (str or ImageReadMode): the read mode used for optionally converting the image.
302307 Default: ``ImageReadMode.UNCHANGED``.
303308 See ``ImageReadMode`` class for more information on various
304309 available modes. Only applies to JPEG and PNG images.
@@ -312,6 +317,8 @@ def decode_image(
312317 _log_api_usage_once (decode_image )
313318 if not isinstance (input , torch .Tensor ):
314319 input = read_file (str (input ))
320+ if isinstance (mode , str ):
321+ mode = ImageReadMode [mode .upper ()]
315322 output = torch .ops .image .decode_image (input , mode .value , apply_exif_orientation )
316323 return output
317324
@@ -360,7 +367,7 @@ def decode_webp(
360367 Args:
361368 input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
362369 the raw bytes of the WEBP image.
363- mode (ImageReadMode): The read mode used for optionally
370+ mode (str or ImageReadMode): The read mode used for optionally
364371 converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
365372 Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
366373
@@ -369,6 +376,8 @@ def decode_webp(
369376 """
370377 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
371378 _log_api_usage_once (decode_webp )
379+ if isinstance (mode , str ):
380+ mode = ImageReadMode [mode .upper ()]
372381 return torch .ops .image .decode_webp (input , mode .value )
373382
374383
@@ -389,7 +398,7 @@ def _decode_avif(
389398 Args:
390399 input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
391400 the raw bytes of the AVIF image.
392- mode (ImageReadMode): The read mode used for optionally
401+ mode (str or ImageReadMode): The read mode used for optionally
393402 converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
394403 Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
395404
@@ -398,6 +407,8 @@ def _decode_avif(
398407 """
399408 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
400409 _log_api_usage_once (_decode_avif )
410+ if isinstance (mode , str ):
411+ mode = ImageReadMode [mode .upper ()]
401412 return torch .ops .image .decode_avif (input , mode .value )
402413
403414
@@ -415,7 +426,7 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
415426 Args:
416427 input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
417428 the raw bytes of the HEIC image.
418- mode (ImageReadMode): The read mode used for optionally
429+ mode (str or ImageReadMode): The read mode used for optionally
419430 converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
420431 Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
421432
@@ -424,4 +435,6 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
424435 """
425436 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
426437 _log_api_usage_once (_decode_heic )
438+ if isinstance (mode , str ):
439+ mode = ImageReadMode [mode .upper ()]
427440 return torch .ops .image .decode_heic (input , mode .value )
0 commit comments