Skip to content

Commit 9d632ca

Browse files
committed
Remove unused functions and fix some types
1 parent ecbefdb commit 9d632ca

File tree

2 files changed

+82
-197
lines changed

2 files changed

+82
-197
lines changed

src/pydvl/utils/array.py

Lines changed: 80 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,36 @@
33
It provides a consistent interface for operations on both NumPy arrays and PyTorch tensors.
44
55
The functions in this module are designed to:
6+
67
1. Detect array types automatically (numpy.ndarray or torch.Tensor)
78
2. Perform operations using the appropriate library
8-
3. Preserve the input type in the output
9+
3. Preserve the input type in the output, except for functions intended to operate on
10+
indices, which always return NDArrays for convenience.
911
4. Minimize unnecessary type conversions
1012
11-
Usage examples:
13+
??? example "Some examples"
1214
13-
```python
14-
import numpy as np
15-
import torch
16-
from pydvl.utils.array import array_zeros, array_concatenate, is_tensor
15+
```python
16+
import numpy as np
17+
import torch
18+
from pydvl.utils.array import array_zeros, array_concatenate, is_tensor
1719
18-
# Works with NumPy arrays
19-
x_np = np.array([1, 2, 3])
20-
zeros_np = array_zeros((3,), like=x_np) # Returns numpy.ndarray
20+
# Works with NumPy arrays
21+
x_np = np.array([1, 2, 3])
22+
zeros_np = array_zeros((3, ), like=x_np) # Returns numpy.ndarray
2123
22-
# Works with PyTorch tensors
23-
x_torch = torch.tensor([1, 2, 3])
24-
zeros_torch = array_zeros((3,), like=x_torch) # Returns torch.Tensor
24+
# Works with PyTorch tensors
25+
x_torch = torch.tensor([1, 2, 3])
26+
zeros_torch = array_zeros((3, ), like=x_torch) # Returns torch.Tensor
2527
26-
# Type checking
27-
is_tensor(x_torch) # Returns True
28-
is_tensor(x_np) # Returns False
28+
# Type checking
29+
is_tensor(x_torch) # Returns True
30+
is_tensor(x_np) # Returns False
2931
30-
# Operations preserve types
31-
result = array_concatenate([x_np, zeros_np]) # Returns numpy.ndarray
32-
result = array_concatenate([x_torch, zeros_torch]) # Returns torch.Tensor
33-
```
32+
# Operations preserve types
33+
result = array_concatenate([x_np, zeros_np]) # Returns numpy.ndarray
34+
result = array_concatenate([x_torch, zeros_torch]) # Returns torch.Tensor
35+
```
3436
3537
The module uses a TypeVar `ArrayT` to ensure type preservation across functions,
3638
allowing for proper static type checking with both array types.
@@ -39,6 +41,7 @@
3941
from __future__ import annotations
4042

4143
import warnings
44+
from numbers import Number
4245
from types import ModuleType
4346
from typing import (
4447
TYPE_CHECKING,
@@ -70,12 +73,10 @@
7073
"array_ones",
7174
"array_zeros_like",
7275
"array_ones_like",
73-
"array_where",
7476
"array_unique",
7577
"array_concatenate",
7678
"array_equal",
7779
"array_arange",
78-
"array_slice",
7980
"array_index",
8081
"array_exp",
8182
"array_count_nonzero",
@@ -148,6 +149,8 @@ class Array(Protocol[DT]):
148149
in NumPy and PyTorch arrays.
149150
"""
150151

152+
nbytes: int
153+
151154
@property
152155
def shape(self) -> tuple[int, ...]: ...
153156

@@ -287,98 +290,65 @@ def array_ones(
287290

288291

289292
@overload
290-
def array_zeros_like(array: NDArray, dtype: Any | None = None) -> NDArray: ...
293+
def array_zeros_like(a: NDArray, dtype: Any | None = None) -> NDArray: ...
291294

292295

293296
@overload
294-
def array_zeros_like(array: Tensor, dtype: Any | None = None) -> Tensor: ...
297+
def array_zeros_like(a: Tensor, dtype: Any | None = None) -> Tensor: ...
295298

296299

297-
def array_zeros_like(array: Array, dtype: Any | None = None) -> Array:
300+
def array_zeros_like(a: NDArray | Tensor, dtype: Any | None = None) -> NDArray | Tensor:
298301
"""
299302
Create a zero-filled array with the same shape and type as `a`.
300303
301304
Args:
302-
array: Reference array (numpy array or torch tensor).
305+
a: Reference array (numpy array or torch tensor).
303306
dtype: Data type (optional).
304307
305308
Returns:
306309
An array of zeros matching the input type.
310+
311+
Raises:
312+
TypeError: If `a` is neither a NumPy array nor a PyTorch tensor.
307313
"""
308-
if is_tensor(array):
314+
if is_tensor(a):
309315
assert torch is not None # Keep mypy happy
310-
tensor_array = cast(Tensor, array)
311-
return cast(Array, torch.zeros_like(tensor_array, dtype=dtype))
312-
return cast(Array, np.zeros_like(array, dtype=dtype))
316+
tensor_array = cast(Tensor, a)
317+
return cast(Tensor, torch.zeros_like(tensor_array, dtype=dtype))
318+
elif is_numpy(a):
319+
return cast(NDArray, np.zeros_like(a, dtype=dtype))
320+
raise TypeError(f"Unsupported array type: {type(a).__name__}")
313321

314322

315323
@overload
316-
def array_ones_like(array: NDArray, dtype: Any | None = None) -> NDArray: ...
324+
def array_ones_like(a: NDArray, dtype: Any | None = None) -> NDArray: ...
317325

318326

319327
@overload
320-
def array_ones_like(array: Tensor, dtype: Any | None = None) -> Tensor: ...
328+
def array_ones_like(a: Tensor, dtype: Any | None = None) -> Tensor: ...
321329

322330

323-
def array_ones_like(array: Array, dtype: Any | None = None) -> Array:
331+
def array_ones_like(a: NDArray | Tensor, dtype: Any | None = None) -> NDArray | Tensor:
324332
"""
325333
Create a one-filled array with the same shape and type as `a`.
326334
327335
Args:
328-
array: Reference array (numpy array or torch tensor).
336+
a: Reference array (numpy array or torch tensor).
329337
dtype: Data type (optional).
330338
331339
Returns:
332340
An array of ones matching the input type.
333-
"""
334-
if is_numpy(array):
335-
return cast(Array, np.ones_like(array, dtype=dtype))
336-
elif is_tensor(array):
337-
assert torch is not None
338-
tensor_array = cast(Tensor, array)
339-
return cast(Array, torch.ones_like(tensor_array, dtype=dtype))
340-
raise TypeError(f"Unsupported array type: {type(array).__name__}")
341-
342-
343-
@overload
344-
def array_where(condition: NDArray, x: NDArray, y: NDArray) -> NDArray: ...
345-
346-
347-
@overload
348-
def array_where(condition: Tensor, x: Tensor, y: Tensor) -> Tensor: ...
349-
350-
351-
def array_where(condition: Array, x: Array, y: Array) -> Array:
352-
"""
353-
Return elements chosen from x or y depending on condition.
354-
355-
Args:
356-
condition: Boolean mask.
357-
x: Values selected where condition is True.
358-
y: Values selected where condition is False.
359341
360-
Returns:
361-
An array with elements from x or y, following the input type.
342+
Raises:
343+
TypeError: If `a` is neither a NumPy array nor a PyTorch tensor.
362344
"""
363-
# If any of the inputs is a tensor and torch is available, work with torch.
364-
if any(is_tensor(a) for a in (condition, x, y)):
345+
if is_numpy(a):
346+
return cast(NDArray, np.ones_like(a, dtype=dtype))
347+
elif is_tensor(a):
365348
assert torch is not None
366-
device = None
367-
for a in (condition, x, y):
368-
if is_tensor(a):
369-
device = cast(Tensor, a).device
370-
break
371-
372-
condition_tensor = (
373-
condition
374-
if is_tensor(condition)
375-
else torch.as_tensor(to_numpy(condition), device=device)
376-
)
377-
x_tensor = x if is_tensor(x) else torch.as_tensor(to_numpy(x), device=device)
378-
y_tensor = y if is_tensor(y) else torch.as_tensor(to_numpy(y), device=device)
379-
return cast(Array, torch.where(condition_tensor, x_tensor, y_tensor))
380-
else:
381-
return cast(Array, np.where(condition, x, y))
349+
tensor_array = cast(Tensor, a)
350+
return cast(Array, torch.ones_like(tensor_array, dtype=dtype))
351+
raise TypeError(f"Unsupported array type: {type(a).__name__}")
382352

383353

384354
@overload
@@ -402,12 +372,12 @@ def array_unique(
402372
@overload
403373
def array_unique(
404374
array: Tensor, return_index: Literal[True], **kwargs: Any
405-
) -> Tuple[Tensor, Tensor]: ...
375+
) -> Tuple[Tensor, NDArray]: ...
406376

407377

408378
def array_unique(
409-
array: ArrayT, return_index: bool = False, **kwargs: Any
410-
) -> Union[ArrayT, Tuple[ArrayT, NDArray]]:
379+
array: NDArray | Tensor, return_index: bool = False, **kwargs: Any
380+
) -> Union[NDArray | Tensor, Tuple[NDArray | Tensor, NDArray]]:
411381
"""
412382
Return the unique elements in an array, optionally with indices of their first
413383
occurrences.
@@ -436,8 +406,8 @@ def array_unique(
436406
indices_tensor = torch.tensor(
437407
indices, dtype=torch.long, device=tensor_array.device
438408
)
439-
return cast(Tuple[ArrayT, NDArray], (result, indices_tensor.cpu().numpy()))
440-
return cast(ArrayT, result)
409+
return cast(Tuple[Tensor, NDArray], (result, indices_tensor.cpu().numpy()))
410+
return cast(Tensor, result)
441411
else: # Fallback to numpy approach.
442412
numpy_array = to_numpy(array)
443413
if return_index:
@@ -447,15 +417,15 @@ def array_unique(
447417
return_index=True,
448418
**{k: v for k, v in kwargs.items() if k != "return_index"},
449419
)
450-
return cast(Tuple[ArrayT, NDArray], (unique_vals, indices))
420+
return cast(Tuple[NDArray, NDArray], (unique_vals, indices))
451421
else:
452422
# Simple case - just unique values
453423
result = np.unique(
454424
numpy_array,
455425
return_index=False,
456426
**{k: v for k, v in kwargs.items() if k != "return_index"},
457427
)
458-
return cast(ArrayT, result)
428+
return cast(NDArray, result)
459429

460430

461431
@overload
@@ -466,7 +436,9 @@ def array_concatenate(arrays: Sequence[NDArray], axis: int = 0) -> NDArray: ...
466436
def array_concatenate(arrays: Sequence[Tensor], axis: int = 0) -> Tensor: ...
467437

468438

469-
def array_concatenate(arrays: Sequence[Array], axis: int = 0) -> Array:
439+
def array_concatenate(
440+
arrays: Sequence[NDArray | Tensor], axis: int = 0
441+
) -> NDArray | Tensor:
470442
"""
471443
Join a sequence of arrays along an existing axis.
472444
@@ -496,10 +468,10 @@ def array_concatenate(arrays: Sequence[Array], axis: int = 0) -> Array:
496468
tensor_arrays.append(a)
497469
else:
498470
tensor_arrays.append(torch.as_tensor(to_numpy(a), device=device))
499-
return cast(Array, torch.cat(tensor_arrays, dim=axis))
471+
return cast(Tensor, torch.cat(tensor_arrays, dim=axis))
500472
# Otherwise, convert all arrays to numpy arrays.
501473
numpy_arrays = [to_numpy(a) for a in arrays]
502-
return cast(Array, np.concatenate(numpy_arrays, axis=axis))
474+
return cast(NDArray, np.concatenate(numpy_arrays, axis=axis))
503475

504476

505477
def array_equal(array1: Array[Any], array2: Array[Any]) -> bool:
@@ -569,35 +541,6 @@ def array_arange(
569541
return cast(Array, np.arange(start, stop, step, dtype=dtype)) # type: ignore
570542

571543

572-
@overload
573-
def array_slice(array: NDArray, indices: Any) -> NDArray: ...
574-
575-
576-
@overload
577-
def array_slice(array: Tensor, indices: Any) -> Tensor: ...
578-
579-
580-
def array_slice(array: Array, indices: Any) -> Array:
581-
"""
582-
Slice an array in a type-agnostic way.
583-
584-
Args:
585-
array: The array to be sliced.
586-
indices: The slicing indices (int, slice, list, etc.).
587-
588-
Returns:
589-
A sliced array with the same type as the input.
590-
591-
Raises:
592-
TypeError: If the input array does not support indexing.
593-
"""
594-
if not hasattr(array, "__getitem__"):
595-
raise TypeError(
596-
f"Provided object of type {type(array).__name__} is not indexable."
597-
)
598-
return array[indices]
599-
600-
601544
def array_index(array: Array, key: Array, dim: int = 0) -> Array:
602545
"""
603546
Index into an array along the specified dimension.
@@ -703,6 +646,10 @@ def stratified_split_indices(
703646
return cast(Tuple[ArrayT, ArrayT], (train_indices, test_indices))
704647

705648

649+
@overload
650+
def atleast1d(a: Number) -> NDArray: ...
651+
652+
706653
@overload
707654
def atleast1d(a: NDArray) -> NDArray: ...
708655

@@ -711,25 +658,25 @@ def atleast1d(a: NDArray) -> NDArray: ...
711658
def atleast1d(a: Tensor) -> Tensor: ...
712659

713660

714-
def atleast1d(a: Array) -> Array:
715-
"""
716-
Ensures that the array is at least 1D.
661+
def atleast1d(a: NDArray | Tensor | Number) -> NDArray | Tensor:
662+
"""Ensures that the input is at least 1D.
663+
664+
For scalar builtin types, the output is an NDArray. Scalar tensors are converted to
665+
1D tensors.
717666
718667
Args:
719-
a: Input array-like object.
668+
a: Input array-like object or a scalar.
720669
721670
Returns:
722-
The array, as a 1D structure.
671+
The input, as a 1D structure.
723672
"""
724-
if is_numpy(a):
725-
return cast(Array, np.atleast_1d(a))
673+
if is_numpy(a) or np.isscalar(a):
674+
return cast(NDArray, np.atleast_1d(a)) # type: ignore
726675
if is_tensor(a):
727676
assert torch is not None
728-
tensor_array = cast(Tensor, a)
729-
return cast(
730-
Array, tensor_array.unsqueeze(0) if tensor_array.ndim == 0 else tensor_array
731-
)
732-
return cast(Array, np.atleast_1d(np.array(a)))
677+
t = cast(Tensor, a)
678+
return cast(Tensor, t.unsqueeze(0) if t.ndim == 0 else t)
679+
raise TypeError(f"Unsupported array or scalar type: {type(a).__name__}")
733680

734681

735682
@overload
@@ -755,13 +702,13 @@ def check_X_y(
755702

756703

757704
def check_X_y(
758-
X: Array,
759-
y: Array,
705+
X: NDArray | Tensor,
706+
y: NDArray | Tensor,
760707
*,
761708
multi_output: bool = False,
762709
estimator: str | object | None = None,
763710
copy: bool = False,
764-
) -> Tuple[Array, Array]:
711+
) -> Tuple[NDArray | Tensor, NDArray | Tensor]:
765712
"""
766713
Validate X and y mimicking the functionality of sklearn's check_X_y.
767714

0 commit comments

Comments
 (0)