|
1 | 1 | #!/usr/bin/env python
|
2 | 2 | from inspect import getdoc, getmembers, isfunction
|
| 3 | +from typing import Any, Callable, Mapping, Sequence, Union |
3 | 4 |
|
4 | 5 | import matplotlib.pyplot as plt
|
5 | 6 | import torch
|
6 | 7 |
|
7 | 8 |
|
| 9 | +# Taken from |
| 10 | +# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/apply_func.py |
| 11 | +def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: |
| 12 | + """ |
| 13 | + Recursively applies a function to all elements of a certain dtype. |
| 14 | + Args: |
| 15 | + data: the collection to apply the function to |
| 16 | + dtype: the given function will be applied to all elements of this dtype |
| 17 | + function: the function to apply |
| 18 | + *args: positional arguments (will be forwarded to calls of ``function``) |
| 19 | + **kwargs: keyword arguments (will be forwarded to calls of ``function``) |
| 20 | + Returns: |
| 21 | + the resulting collection |
| 22 | + """ |
| 23 | + elem_type = type(data) |
| 24 | + |
| 25 | + # Breaking condition |
| 26 | + if isinstance(data, dtype): |
| 27 | + return function(data, *args, **kwargs) |
| 28 | + |
| 29 | + # Recursively apply to collection items |
| 30 | + if isinstance(data, Mapping): |
| 31 | + return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) |
| 32 | + |
| 33 | + if isinstance(data, tuple) and hasattr(data, "_fields"): # named tuple |
| 34 | + return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) |
| 35 | + |
| 36 | + if isinstance(data, Sequence) and not isinstance(data, str): |
| 37 | + return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) |
| 38 | + |
| 39 | + # data is neither of dtype, nor a collection |
| 40 | + return data |
| 41 | + |
| 42 | + |
8 | 43 | # Function to convert a list of arguments containing torch tensors, into
|
9 | 44 | # a corresponding list of arguments containing numpy arrays
|
10 | 45 | def _torch2np(*args, **kwargs):
|
| 46 | + """ |
| 47 | + Convert a list of arguments containing torch tensors into a list of |
| 48 | + arguments containing numpy arrays |
| 49 | + """ |
| 50 | + |
11 | 51 | def convert(arg):
|
12 |
| - return arg.detach().cpu().numpy() if isinstance(arg, torch.Tensor) else arg |
| 52 | + return arg.detach().cpu().numpy() |
13 | 53 |
|
14 | 54 | # first unnamed arguments
|
15 |
| - outargs = [convert(arg) for arg in args] |
| 55 | + outargs = apply_to_collection(args, torch.Tensor, convert) |
16 | 56 |
|
17 | 57 | # then keyword arguments
|
18 |
| - outkwargs = dict() |
19 |
| - for key, value in kwargs.items(): |
20 |
| - outkwargs[key] = convert(value) |
| 58 | + outkwargs = apply_to_collection(kwargs, torch.Tensor, convert) |
21 | 59 |
|
22 |
| - return outargs, kwargs |
| 60 | + return outargs, outkwargs |
23 | 61 |
|
24 | 62 |
|
25 | 63 | # Iterate over all members of 'plt' in order to duplicate them
|
|
0 commit comments