Skip to content

Commit 5557cbc

Browse files
committed
add recursive convert
1 parent 2052823 commit 5557cbc

File tree

2 files changed

+52
-8
lines changed

2 files changed

+52
-8
lines changed

tests/test_torchplot.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import string
1514
from collections import namedtuple
1615
from inspect import getmembers, isfunction
1716

@@ -50,7 +49,10 @@ def test_members(member):
5049
@pytest.mark.parametrize("test_case", _cases)
5150
def test_cpu(test_case):
5251
""" test that it works on cpu """
52+
# passed as args
5353
assert tp.plot(test_case.x, test_case.y, ".")
54+
# passed as kwargs
55+
assert tp.scatter(x=test_case.x, y=test_case.y)
5456

5557

5658
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@@ -60,5 +62,9 @@ def test_gpu(test_case):
6062
assert tp.plot(
6163
test_case.x.cuda() if isinstance(test_case.x, torch.Tensor) else test_case.x,
6264
test_case.y.cuda() if isinstance(test_case.y, torch.Tensor) else test_case.y,
63-
".",
65+
)
66+
# passed as kwargs
67+
assert tp.scatter(
68+
x=test_case.x.cuda() if isinstance(test_case.x, torch.Tensor) else test_case.x,
69+
y=test_case.y.cuda() if isinstance(test_case.y, torch.Tensor) else test_case.y,
6470
)

torchplot/core.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,63 @@
11
#!/usr/bin/env python
22
from inspect import getdoc, getmembers, isfunction
3+
from typing import Any, Callable, Mapping, Sequence, Union
34

45
import matplotlib.pyplot as plt
56
import torch
67

78

9+
# Taken from
10+
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/apply_func.py
11+
def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
12+
"""
13+
Recursively applies a function to all elements of a certain dtype.
14+
Args:
15+
data: the collection to apply the function to
16+
dtype: the given function will be applied to all elements of this dtype
17+
function: the function to apply
18+
*args: positional arguments (will be forwarded to calls of ``function``)
19+
**kwargs: keyword arguments (will be forwarded to calls of ``function``)
20+
Returns:
21+
the resulting collection
22+
"""
23+
elem_type = type(data)
24+
25+
# Breaking condition
26+
if isinstance(data, dtype):
27+
return function(data, *args, **kwargs)
28+
29+
# Recursively apply to collection items
30+
if isinstance(data, Mapping):
31+
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
32+
33+
if isinstance(data, tuple) and hasattr(data, "_fields"): # named tuple
34+
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
35+
36+
if isinstance(data, Sequence) and not isinstance(data, str):
37+
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
38+
39+
# data is neither of dtype, nor a collection
40+
return data
41+
42+
843
# Function to convert a list of arguments containing torch tensors, into
944
# a corresponding list of arguments containing numpy arrays
1045
def _torch2np(*args, **kwargs):
46+
"""
47+
Convert a list of arguments containing torch tensors into a list of
48+
arguments containing numpy arrays
49+
"""
50+
1151
def convert(arg):
12-
return arg.detach().cpu().numpy() if isinstance(arg, torch.Tensor) else arg
52+
return arg.detach().cpu().numpy()
1353

1454
# first unnamed arguments
15-
outargs = [convert(arg) for arg in args]
55+
outargs = apply_to_collection(args, torch.Tensor, convert)
1656

1757
# then keyword arguments
18-
outkwargs = dict()
19-
for key, value in kwargs.items():
20-
outkwargs[key] = convert(value)
58+
outkwargs = apply_to_collection(kwargs, torch.Tensor, convert)
2159

22-
return outargs, kwargs
60+
return outargs, outkwargs
2361

2462

2563
# Iterate over all members of 'plt' in order to duplicate them

0 commit comments

Comments
 (0)