33from abc import ABC , abstractmethod
44from glob import glob
55from pathlib import Path
6- from typing import Callable , List , Optional , Tuple , Union
6+ from typing import Any , Callable , List , Optional , Tuple , Union
77
88import numpy as np
99import torch
1010from PIL import Image
1111
1212from ..io .image import decode_png , read_file
13+ from .folder import default_loader
1314from .utils import _read_pfm , verify_str_arg
1415from .vision import VisionDataset
1516
@@ -32,19 +33,22 @@ class FlowDataset(ABC, VisionDataset):
3233 # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
3334 _has_builtin_flow_mask = False
3435
35- def __init__ (self , root : Union [str , Path ], transforms : Optional [Callable ] = None ) -> None :
36+ def __init__ (
37+ self ,
38+ root : Union [str , Path ],
39+ transforms : Optional [Callable ] = None ,
40+ loader : Callable [[str ], Any ] = default_loader ,
41+ ) -> None :
3642
3743 super ().__init__ (root = root )
3844 self .transforms = transforms
3945
4046 self ._flow_list : List [str ] = []
4147 self ._image_list : List [List [str ]] = []
48+ self ._loader = loader
4249
43- def _read_img (self , file_name : str ) -> Image .Image :
44- img = Image .open (file_name )
45- if img .mode != "RGB" :
46- img = img .convert ("RGB" ) # type: ignore[assignment]
47- return img
50+ def _read_img (self , file_name : str ) -> Union [Image .Image , torch .Tensor ]:
51+ return self ._loader (file_name )
4852
4953 @abstractmethod
5054 def _read_flow (self , file_name : str ):
@@ -70,9 +74,9 @@ def __getitem__(self, index: int) -> Union[T1, T2]:
7074
7175 if self ._has_builtin_flow_mask or valid_flow_mask is not None :
7276 # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
73- return img1 , img2 , flow , valid_flow_mask
77+ return img1 , img2 , flow , valid_flow_mask # type: ignore[return-value]
7478 else :
75- return img1 , img2 , flow
79+ return img1 , img2 , flow # type: ignore[return-value]
7680
7781 def __len__ (self ) -> int :
7882 return len (self ._image_list )
@@ -120,6 +124,9 @@ class Sintel(FlowDataset):
120124 ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
121125 ``valid_flow_mask`` is expected for consistency with other datasets which
122126 return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
127+ loader (callable, optional): A function to load an image given its path.
128+ By default, it uses PIL as its image loader, but users could also pass in
129+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
123130 """
124131
125132 def __init__ (
@@ -128,8 +135,9 @@ def __init__(
128135 split : str = "train" ,
129136 pass_name : str = "clean" ,
130137 transforms : Optional [Callable ] = None ,
138+ loader : Callable [[str ], Any ] = default_loader ,
131139 ) -> None :
132- super ().__init__ (root = root , transforms = transforms )
140+ super ().__init__ (root = root , transforms = transforms , loader = loader )
133141
134142 verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
135143 verify_str_arg (pass_name , "pass_name" , valid_values = ("clean" , "final" , "both" ))
@@ -186,12 +194,21 @@ class KittiFlow(FlowDataset):
186194 split (string, optional): The dataset split, either "train" (default) or "test"
187195 transforms (callable, optional): A function/transform that takes in
188196 ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
197+ loader (callable, optional): A function to load an image given its path.
198+ By default, it uses PIL as its image loader, but users could also pass in
199+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
189200 """
190201
191202 _has_builtin_flow_mask = True
192203
193- def __init__ (self , root : Union [str , Path ], split : str = "train" , transforms : Optional [Callable ] = None ) -> None :
194- super ().__init__ (root = root , transforms = transforms )
204+ def __init__ (
205+ self ,
206+ root : Union [str , Path ],
207+ split : str = "train" ,
208+ transforms : Optional [Callable ] = None ,
209+ loader : Callable [[str ], Any ] = default_loader ,
210+ ) -> None :
211+ super ().__init__ (root = root , transforms = transforms , loader = loader )
195212
196213 verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
197214
@@ -324,6 +341,9 @@ class FlyingThings3D(FlowDataset):
324341 ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
325342 ``valid_flow_mask`` is expected for consistency with other datasets which
326343 return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
344+ loader (callable, optional): A function to load an image given its path.
345+ By default, it uses PIL as its image loader, but users could also pass in
346+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
327347 """
328348
329349 def __init__ (
@@ -333,8 +353,9 @@ def __init__(
333353 pass_name : str = "clean" ,
334354 camera : str = "left" ,
335355 transforms : Optional [Callable ] = None ,
356+ loader : Callable [[str ], Any ] = default_loader ,
336357 ) -> None :
337- super ().__init__ (root = root , transforms = transforms )
358+ super ().__init__ (root = root , transforms = transforms , loader = loader )
338359
339360 verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
340361 split = split .upper ()
@@ -414,12 +435,21 @@ class HD1K(FlowDataset):
414435 split (string, optional): The dataset split, either "train" (default) or "test"
415436 transforms (callable, optional): A function/transform that takes in
416437 ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
438+ loader (callable, optional): A function to load an image given its path.
439+ By default, it uses PIL as its image loader, but users could also pass in
440+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
417441 """
418442
419443 _has_builtin_flow_mask = True
420444
421- def __init__ (self , root : Union [str , Path ], split : str = "train" , transforms : Optional [Callable ] = None ) -> None :
422- super ().__init__ (root = root , transforms = transforms )
445+ def __init__ (
446+ self ,
447+ root : Union [str , Path ],
448+ split : str = "train" ,
449+ transforms : Optional [Callable ] = None ,
450+ loader : Callable [[str ], Any ] = default_loader ,
451+ ) -> None :
452+ super ().__init__ (root = root , transforms = transforms , loader = loader )
423453
424454 verify_str_arg (split , "split" , valid_values = ("train" , "test" ))
425455
0 commit comments