|
15 | 15 | ```python |
16 | 16 | import numpy as np |
17 | 17 | import torch |
18 | | - from pydvl.utils.array import array_zeros, array_concatenate, is_tensor |
19 | | -
|
20 | | - # Works with NumPy arrays |
21 | | - x_np = np.array([1, 2, 3]) |
22 | | - zeros_np = array_zeros((3, ), like=x_np) # Returns numpy.ndarray |
23 | | -
|
24 | | - # Works with PyTorch tensors |
25 | | - x_torch = torch.tensor([1, 2, 3]) |
26 | | - zeros_torch = array_zeros((3, ), like=x_torch) # Returns torch.Tensor |
| 18 | + from pydvl.utils.array import array_concatenate, is_tensor |
27 | 19 |
|
28 | 20 | # Type checking |
29 | 21 | is_tensor(x_torch) # Returns True |
|
40 | 32 |
|
41 | 33 | from __future__ import annotations |
42 | 34 |
|
43 | | -import warnings |
44 | 35 | from numbers import Number |
45 | 36 | from types import ModuleType |
46 | 37 | from typing import ( |
|
69 | 60 | "is_numpy", |
70 | 61 | "to_tensor", |
71 | 62 | "to_numpy", |
72 | | - "array_zeros", |
73 | | - "array_ones", |
74 | | - "array_zeros_like", |
75 | | - "array_ones_like", |
76 | 63 | "array_unique", |
77 | 64 | "array_concatenate", |
78 | | - "array_equal", |
79 | | - "array_arange", |
80 | | - "array_index", |
81 | | - "array_exp", |
82 | 65 | "array_count_nonzero", |
83 | 66 | "array_nonzero", |
84 | 67 | "is_categorical", |
@@ -118,6 +101,12 @@ def try_torch_import(require: bool = False) -> ModuleType | None: |
118 | 101 | Tensor = Any if torch is None else torch.Tensor |
119 | 102 |
|
120 | 103 |
|
| 104 | +def require_torch() -> ModuleType: |
| 105 | + torch = try_torch_import(require=True) |
| 106 | + assert torch is not None |
| 107 | + return torch |
| 108 | + |
| 109 | + |
121 | 110 | def is_tensor(array: Any) -> bool: |
122 | 111 | """Check if an array is a PyTorch tensor.""" |
123 | 112 | return torch is not None and isinstance(array, torch.Tensor) |
@@ -226,131 +215,6 @@ def to_numpy(array: Array | ArrayLike) -> NDArray: |
226 | 215 | ShapeType = Union[int, Tuple[int, ...], List[int]] |
227 | 216 |
|
228 | 217 |
|
229 | | -def array_zeros( |
230 | | - shape: ShapeType, |
231 | | - *, |
232 | | - dtype: Any | None = None, |
233 | | - like: Array | None = None, |
234 | | -) -> Array: |
235 | | - """ |
236 | | - Create a zero-filled array with the same type as `like`. |
237 | | -
|
238 | | - Args: |
239 | | - shape: Desired shape for the new array. |
240 | | - dtype: Data type (optional). |
241 | | - like: Reference array (numpy array or torch tensor). |
242 | | -
|
243 | | - Returns: |
244 | | - An array of zeros (np.ndarray or torch.Tensor). |
245 | | - """ |
246 | | - if like is None: |
247 | | - return cast(Array, np.zeros(shape, dtype=dtype)) |
248 | | - |
249 | | - if is_numpy(like): |
250 | | - return cast(Array, np.zeros(shape, dtype=dtype)) |
251 | | - elif is_tensor(like): |
252 | | - assert torch is not None |
253 | | - like_tensor = cast(Tensor, like) |
254 | | - return cast(Array, torch.zeros(shape, dtype=dtype, device=like_tensor.device)) |
255 | | - else: |
256 | | - # In case 'like' is an unsupported type, fallback with numpy. |
257 | | - warnings.warn("Reference object type is not recognized. Falling back to numpy.") |
258 | | - return cast(Array, np.zeros(shape, dtype=dtype)) |
259 | | - |
260 | | - |
261 | | -def array_ones( |
262 | | - shape: ShapeType, |
263 | | - *, |
264 | | - dtype: Any | None = None, |
265 | | - like: Array | None = None, |
266 | | -) -> Array: |
267 | | - """ |
268 | | - Create a one-filled array with the same type as `like`. |
269 | | -
|
270 | | - Args: |
271 | | - shape: Desired shape for the new array. |
272 | | - dtype: Data type (optional). |
273 | | - like: Reference array (numpy array or torch tensor). |
274 | | -
|
275 | | - Returns: |
276 | | - An array of ones (np.ndarray or torch.Tensor). |
277 | | - """ |
278 | | - if like is None: |
279 | | - return cast(Array, np.ones(shape, dtype=dtype)) |
280 | | - |
281 | | - if is_numpy(like): |
282 | | - return cast(Array, np.ones(shape, dtype=dtype)) |
283 | | - elif is_tensor(like): |
284 | | - assert torch is not None |
285 | | - like_tensor = cast(Tensor, like) |
286 | | - return cast(Array, torch.ones(shape, dtype=dtype, device=like_tensor.device)) |
287 | | - else: |
288 | | - warnings.warn("Reference object type is not recognized. Falling back to numpy.") |
289 | | - return cast(Array, np.ones(shape, dtype=dtype)) |
290 | | - |
291 | | - |
292 | | -@overload |
293 | | -def array_zeros_like(a: NDArray, dtype: Any | None = None) -> NDArray: ... |
294 | | - |
295 | | - |
296 | | -@overload |
297 | | -def array_zeros_like(a: Tensor, dtype: Any | None = None) -> Tensor: ... |
298 | | - |
299 | | - |
300 | | -def array_zeros_like(a: NDArray | Tensor, dtype: Any | None = None) -> NDArray | Tensor: |
301 | | - """ |
302 | | - Create a zero-filled array with the same shape and type as `array`. |
303 | | -
|
304 | | - Args: |
305 | | - a: Reference array (numpy array or torch tensor). |
306 | | - dtype: Data type (optional). |
307 | | -
|
308 | | - Returns: |
309 | | - An array of zeros matching the input type. |
310 | | -
|
311 | | - Raises: |
312 | | - TypeError: |
313 | | - """ |
314 | | - if is_tensor(a): |
315 | | - assert torch is not None # Keep mypy happy |
316 | | - tensor_array = cast(Tensor, a) |
317 | | - return cast(Tensor, torch.zeros_like(tensor_array, dtype=dtype)) |
318 | | - elif is_numpy(a): |
319 | | - return cast(NDArray, np.zeros_like(a, dtype=dtype)) |
320 | | - raise TypeError(f"Unsupported array type: {type(a).__name__}") |
321 | | - |
322 | | - |
323 | | -@overload |
324 | | -def array_ones_like(a: NDArray, dtype: Any | None = None) -> NDArray: ... |
325 | | - |
326 | | - |
327 | | -@overload |
328 | | -def array_ones_like(a: Tensor, dtype: Any | None = None) -> Tensor: ... |
329 | | - |
330 | | - |
331 | | -def array_ones_like(a: NDArray | Tensor, dtype: Any | None = None) -> NDArray | Tensor: |
332 | | - """ |
333 | | - Create a one-filled array with the same shape and type as `array`. |
334 | | -
|
335 | | - Args: |
336 | | - a: Reference array (numpy array or torch tensor). |
337 | | - dtype: Data type (optional). |
338 | | -
|
339 | | - Returns: |
340 | | - An array of ones matching the input type. |
341 | | -
|
342 | | - Raises: |
343 | | - TypeError: |
344 | | - """ |
345 | | - if is_numpy(a): |
346 | | - return cast(NDArray, np.ones_like(a, dtype=dtype)) |
347 | | - elif is_tensor(a): |
348 | | - assert torch is not None |
349 | | - tensor_array = cast(Tensor, a) |
350 | | - return cast(Array, torch.ones_like(tensor_array, dtype=dtype)) |
351 | | - raise TypeError(f"Unsupported array type: {type(a).__name__}") |
352 | | - |
353 | | - |
354 | 218 | @overload |
355 | 219 | def array_unique( |
356 | 220 | array: NDArray, return_index: Literal[False] = False, **kwargs: Any |
@@ -474,122 +338,6 @@ def array_concatenate( |
474 | 338 | return cast(NDArray, np.concatenate(numpy_arrays, axis=axis)) |
475 | 339 |
|
476 | 340 |
|
477 | | -def array_equal(array1: Array[Any], array2: Array[Any]) -> bool: |
478 | | - """ |
479 | | - Check if two arrays are element-wise equal. |
480 | | -
|
481 | | - Args: |
482 | | - array1: First array. |
483 | | - array2: Second array. |
484 | | -
|
485 | | - Returns: |
486 | | - True if arrays are equal, otherwise False. |
487 | | - """ |
488 | | - if is_numpy(array1) and is_numpy(array2): |
489 | | - return bool(np.array_equal(array1, array2)) |
490 | | - elif is_tensor(array1) and is_tensor(array2) and torch is not None: |
491 | | - return bool(torch.equal(cast(Tensor, array1), cast(Tensor, array2))) |
492 | | - # Fall back to comparing numpy representations. |
493 | | - return bool(np.array_equal(to_numpy(array1), to_numpy(array2))) |
494 | | - |
495 | | - |
496 | | -def array_arange( |
497 | | - start: int, |
498 | | - stop: int | None = None, |
499 | | - step: int = 1, |
500 | | - *, |
501 | | - dtype: DT | None = None, |
502 | | - like: Array | None = None, |
503 | | -) -> Array[DT]: |
504 | | - """ |
505 | | - Create an array with evenly spaced values within a given interval. |
506 | | -
|
507 | | - Args: |
508 | | - start: Start of interval, or stop if stop is None. |
509 | | - stop: End of interval (exclusive). |
510 | | - step: Step size. |
511 | | - dtype: Data type (optional). |
512 | | - like: Reference array to infer the type and device. |
513 | | -
|
514 | | - Returns: |
515 | | - An array (numpy.ndarray or torch.Tensor) of the specified range. |
516 | | - """ |
517 | | - if stop is None: |
518 | | - start, stop = 0, start |
519 | | - |
520 | | - # If a reference is provided and is recognized as numpy or tensor, use it. |
521 | | - if like is not None: |
522 | | - if is_numpy(like): |
523 | | - return cast(Array, np.arange(start, stop, step, dtype=like.dtype)) |
524 | | - elif is_tensor(like): |
525 | | - assert torch is not None |
526 | | - like_tensor = cast(Tensor, like) |
527 | | - return cast( |
528 | | - Array, |
529 | | - torch.arange( |
530 | | - start, |
531 | | - stop, |
532 | | - step, |
533 | | - dtype=like_tensor.dtype, |
534 | | - device=like_tensor.device, |
535 | | - ), |
536 | | - ) |
537 | | - else: |
538 | | - warnings.warn( |
539 | | - "Reference object type not recognized. Falling back to numpy." |
540 | | - ) |
541 | | - return cast(Array, np.arange(start, stop, step, dtype=dtype)) # type: ignore |
542 | | - |
543 | | - |
544 | | -def array_index(array: Array, key: Array, dim: int = 0) -> Array: |
545 | | - """ |
546 | | - Index into an array along the specified dimension. |
547 | | -
|
548 | | - Args: |
549 | | - array: The input array. |
550 | | - key: The indices to select. |
551 | | - dim: Dimension along which to index. |
552 | | -
|
553 | | - Returns: |
554 | | - An array of the same type as the input with the specified indexing applied. |
555 | | -
|
556 | | - Raises: |
557 | | - ValueError: If the dimension is out of bounds. |
558 | | - """ |
559 | | - # Verify that dim is within valid range. |
560 | | - if not ( |
561 | | - 0 <= dim < array.shape[0] |
562 | | - if hasattr(array, "shape") and len(array.shape) > 0 |
563 | | - else 0 |
564 | | - ): |
565 | | - raise ValueError( |
566 | | - f"Dimension {dim} is out of bounds for array with shape " |
567 | | - f"{getattr(array, 'shape', None)}." |
568 | | - ) |
569 | | - |
570 | | - if is_numpy(array): |
571 | | - # Handle indexing along specified dimension. |
572 | | - if dim == 0: |
573 | | - return array[key] |
574 | | - elif dim == 1: |
575 | | - return array[:, key] |
576 | | - else: |
577 | | - idx = tuple(slice(None) if i != dim else key for i in range(array.ndim)) |
578 | | - return array[idx] |
579 | | - elif is_tensor(array) and torch is not None: |
580 | | - arr_tensor = cast(Tensor, array) |
581 | | - # Ensure key is a tensor. |
582 | | - key_tensor = ( |
583 | | - key |
584 | | - if is_tensor(key) |
585 | | - else torch.as_tensor(to_numpy(key), device=arr_tensor.device) |
586 | | - ) |
587 | | - key_long = cast(Tensor, key_tensor).to(torch.long) |
588 | | - return cast(Array, torch.index_select(arr_tensor, dim, key_long)) |
589 | | - else: |
590 | | - raise TypeError("Unsupported array type for indexing.") |
591 | | - |
592 | | - |
593 | 341 | ArrayT = TypeVar("ArrayT", bound=Array, contravariant=True) |
594 | 342 | ArrayRetT = TypeVar("ArrayRetT", bound=Array, covariant=True) |
595 | 343 |
|
@@ -835,31 +583,6 @@ def check_X_y_torch( |
835 | 583 | return X, y |
836 | 584 |
|
837 | 585 |
|
838 | | -def require_torch() -> ModuleType: |
839 | | - torch = try_torch_import(require=True) |
840 | | - assert torch is not None |
841 | | - return torch |
842 | | - |
843 | | - |
844 | | -def array_exp( |
845 | | - x: Array, |
846 | | -) -> Array: |
847 | | - """ |
848 | | - Calculate the exponential of array elements. |
849 | | -
|
850 | | - Args: |
851 | | - x: Input array. |
852 | | -
|
853 | | - Returns: |
854 | | - Exponential of each element in the input array. |
855 | | - """ |
856 | | - if is_tensor(x): |
857 | | - assert torch is not None |
858 | | - return cast(Array, torch.exp(cast(Tensor, x))) |
859 | | - else: # Fallback to numpy approach |
860 | | - return cast(Array, np.exp(to_numpy(x))) |
861 | | - |
862 | | - |
863 | 586 | def array_count_nonzero( |
864 | 587 | x: Array, |
865 | 588 | ) -> int: |
|
0 commit comments