Skip to content

Commit dc96ff0

Browse files
committed
type: Add overloads to DataobjImage.get_fdata()
Improve type-checkers' ability to predict dtype
1 parent 4008cc5 commit dc96ff0

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

nibabel/dataobj_images.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424
from .fileholders import FileMap
2525
from .filename_parser import FileSpec
2626

27+
FT = ty.TypeVar('FT', bound=np.floating)
28+
F16 = ty.Literal['float16', 'f2', '|f2', '=f2', '<f2', '>f2']
29+
F32 = ty.Literal['float32', 'f4', '|f4', '=f4', '<f4', '>f4']
30+
F64 = ty.Literal['float64', 'f8', '|f8', '=f8', '<f8', '>f8']
31+
Caching = ty.Literal['fill', 'unchanged']
32+
2733
ArrayImgT = ty.TypeVar('ArrayImgT', bound='DataobjImage')
2834

2935

@@ -39,7 +45,7 @@ def __init__(
3945
header: FileBasedHeader | ty.Mapping | None = None,
4046
extra: ty.Mapping | None = None,
4147
file_map: FileMap | None = None,
42-
):
48+
) -> None:
4349
"""Initialize dataobj image
4450
4551
The datobj image is a combination of (dataobj, header), with optional
@@ -224,11 +230,33 @@ def get_data(self, caching='fill'):
224230
self._data_cache = data
225231
return data
226232

233+
# Types and dtypes, e.g., np.float64 or np.dtype('f8')
234+
@ty.overload
235+
def get_fdata(
236+
self, *, caching: Caching = 'fill', dtype: type[FT] | np.dtype[FT]
237+
) -> npt.NDArray[FT]: ...
238+
@ty.overload
239+
def get_fdata(self, caching: Caching, dtype: type[FT] | np.dtype[FT]) -> npt.NDArray[FT]: ...
240+
# Support string literals
241+
@ty.overload
242+
def get_fdata(self, caching: Caching, dtype: F16) -> npt.NDArray[np.float16]: ...
243+
@ty.overload
244+
def get_fdata(self, caching: Caching, dtype: F32) -> npt.NDArray[np.float32]: ...
245+
@ty.overload
246+
def get_fdata(self, *, caching: Caching = 'fill', dtype: F16) -> npt.NDArray[np.float16]: ...
247+
@ty.overload
248+
def get_fdata(self, *, caching: Caching = 'fill', dtype: F32) -> npt.NDArray[np.float32]: ...
249+
# Double-up on float64 literals and the default (no arguments) case
250+
@ty.overload
251+
def get_fdata(
252+
self, caching: Caching = 'fill', dtype: F64 = 'f8'
253+
) -> npt.NDArray[np.float64]: ...
254+
227255
def get_fdata(
228256
self,
229-
caching: ty.Literal['fill', 'unchanged'] = 'fill',
257+
caching: Caching = 'fill',
230258
dtype: npt.DTypeLike = np.float64,
231-
) -> np.ndarray[ty.Any, np.dtype[np.floating]]:
259+
) -> npt.NDArray[np.floating]:
232260
"""Return floating point image data with necessary scaling applied
233261
234262
The image ``dataobj`` property can be an array proxy or an array. An

0 commit comments

Comments
 (0)