Commit bd49bf7

Add additional torch wrapper tests
1 parent 1ece73b commit bd49bf7

File tree

1 file changed: +123 -6 lines changed


tests/test_torch_wrapper.py

Lines changed: 123 additions & 6 deletions
@@ -2,22 +2,139 @@
 import numpy as np
 import pytest
 import torch
+from jax.tree_util import tree_all, tree_map
 
 from s2fft.utils import torch_wrapper
 
+jax.config.update("jax_enable_x64", True)
+
 
 def sum_abs_square(x: jax.Array) -> float:
     return (abs(x) ** 2).sum()
 
 
-@pytest.mark.parametrize("jax_function", [sum_abs_square])
-@pytest.mark.parametrize("input_shape", [(), (1,), (2,), (3, 4)])
-def test_wrap_as_torch_function_single_arg(rng, input_shape, jax_function):
-    x_jax = jax.numpy.asarray(rng.standard_normal(input_shape))
-    y_jax = jax_function(x_jax)
+def log_sum_exp(x: jax.Array) -> float:
+    xp = x.__array_namespace__()
+    max_x = x.max()
+    return max_x + xp.log(xp.exp(x - max_x).sum())
+
+
+DTYPES = ["float32", "float64"]
+
+INPUT_SHAPES = [(), (1,), (2,), (3, 4)]
+
+PYTREE_STRUCTURES = [(), [(), ((1,), (2, 3))], {"a": [(1,), ()], "b": {"0": (1, 2)}}]
+
+JAX_SINGLE_ARG_FUNCTIONS = [sum_abs_square, log_sum_exp]
+
+
+def generate_pytree(rng, converter, dtype, structure):
+    if isinstance(structure, tuple):
+        if structure == () or all(isinstance(child, int) for child in structure):
+            return converter(rng.standard_normal(structure, dtype=dtype))
+        else:
+            return tuple(
+                generate_pytree(rng, converter, dtype, child) for child in structure
+            )
+    elif isinstance(structure, list):
+        return [generate_pytree(rng, converter, dtype, child) for child in structure]
+    elif isinstance(structure, dict):
+        return {
+            key: generate_pytree(rng, converter, dtype, value)
+            for key, value in structure.items()
+        }
+    else:
+        raise TypeError(
+            f"pytree structure with type {type(structure)} not of recognised type"
+        )
+
+
+@pytest.mark.parametrize("input_shape", INPUT_SHAPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_jax_array_to_torch_tensor(rng, input_shape, dtype):
+    x_jax = jax.numpy.asarray(rng.standard_normal(input_shape, dtype=dtype))
     x_torch = torch_wrapper.jax_array_to_torch_tensor(x_jax)
     assert isinstance(x_torch, torch.Tensor)
+    assert x_torch.dtype == getattr(torch, dtype)
+    np.testing.assert_allclose(np.asarray(x_jax), np.asarray(x_torch))
+
+
+@pytest.mark.parametrize("input_shape", INPUT_SHAPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_torch_tensor_to_jax_array(rng, input_shape, dtype):
+    x_torch = torch.from_numpy(rng.standard_normal(input_shape, dtype=dtype))
+    x_jax = torch_wrapper.torch_tensor_to_jax_array(x_torch)
+    assert isinstance(x_jax, jax.Array)
+    assert x_jax.dtype == dtype
+    np.testing.assert_allclose(np.asarray(x_jax), np.asarray(x_torch))
+
+
+@pytest.mark.parametrize("pytree_structure", PYTREE_STRUCTURES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_tree_map_jax_array_to_torch_tensor(rng, pytree_structure, dtype):
+    jax_pytree = generate_pytree(rng, jax.numpy.asarray, dtype, pytree_structure)
+    torch_pytree = torch_wrapper.tree_map_jax_array_to_torch_tensor(jax_pytree)
+    assert tree_all(
+        tree_map(lambda leaf: isinstance(leaf, jax.Array), jax_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: leaf.dtype == dtype, jax_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: isinstance(leaf, torch.Tensor), torch_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: leaf.dtype == getattr(torch, dtype), torch_pytree),
+    )
+    assert tree_all(
+        tree_map(
+            lambda leaf_1, leaf_2: np.allclose(np.asarray(leaf_1), np.asarray(leaf_2)),
+            torch_pytree,
+            jax_pytree,
+        )
+    )
+
+
+@pytest.mark.parametrize("pytree_structure", PYTREE_STRUCTURES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_tree_map_torch_tensor_to_jax_array(rng, pytree_structure, dtype):
+    torch_pytree = generate_pytree(rng, torch.from_numpy, dtype, pytree_structure)
+    jax_pytree = torch_wrapper.tree_map_torch_tensor_to_jax_array(torch_pytree)
+    assert tree_all(
+        tree_map(lambda leaf: isinstance(leaf, jax.Array), jax_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: leaf.dtype == dtype, jax_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: isinstance(leaf, torch.Tensor), torch_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: leaf.dtype == getattr(torch, dtype), torch_pytree),
+    )
+    assert tree_all(
+        tree_map(
+            lambda leaf_1, leaf_2: np.allclose(np.asarray(leaf_1), np.asarray(leaf_2)),
+            torch_pytree,
+            jax_pytree,
+        )
+    )
+
+
+@pytest.mark.parametrize("jax_function", JAX_SINGLE_ARG_FUNCTIONS)
+@pytest.mark.parametrize("input_shape", INPUT_SHAPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_wrap_as_torch_function_single_arg(rng, input_shape, dtype, jax_function):
+    x_numpy = rng.standard_normal(input_shape, dtype=dtype)
+    x_jax = jax.numpy.asarray(x_numpy)
+    y_jax = jax_function(x_jax)
+    x_torch = torch.tensor(x_numpy, requires_grad=True)
     torch_function = torch_wrapper.wrap_as_torch_function(jax_function)
     y_torch = torch_function(x_torch)
     assert isinstance(y_torch, torch.Tensor)
-    np.testing.assert_allclose(np.asarray(y_jax), np.asarray(y_torch))
+    assert y_torch.dtype == getattr(torch, dtype)
+    np.testing.assert_allclose(np.asarray(y_jax), np.asarray(y_torch.detach()))
+    dy_dx_jax = jax.grad(jax_function)(x_jax)
+    y_torch.backward()
+    assert x_torch.grad.dtype == getattr(torch, dtype)
+    np.testing.assert_allclose(np.asarray(dy_dx_jax), np.asarray(x_torch.grad))
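
Note: these tests use an rng fixture that is not part of this diff. Judging by the rng.standard_normal(shape, dtype=...) calls, it is presumably a numpy.random.Generator provided by the repository's conftest.py. A minimal sketch of such a fixture, assuming that setup (only the fixture name rng comes from the tests; the seed and file location are illustrative):

# conftest.py (sketch only, not part of this commit)
import numpy as np
import pytest


@pytest.fixture
def rng() -> np.random.Generator:
    # Seeded generator so rng.standard_normal(shape, dtype=...) is
    # reproducible across runs; the seed value is arbitrary.
    return np.random.default_rng(42)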

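For context on what wrap_as_torch_function must provide for the final test to pass (matching forward values, plus a backward() gradient equal to jax.grad), below is one standard way such a JAX-to-Torch bridge can be built from torch.autograd.Function and jax.vjp. This is a sketch of the general pattern only, not the actual implementation in s2fft.utils.torch_wrapper, and the name wrap_as_torch_function_sketch is made up for illustration:

import jax
import numpy as np
import torch


def wrap_as_torch_function_sketch(jax_function):
    # Sketch of the JAX-to-Torch bridge pattern the test exercises; not the
    # s2fft implementation. Forward runs the JAX function; backward replays
    # the JAX vector-Jacobian product so Torch autograd sees its gradients.
    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x_torch):
            # Move the data into JAX and keep the VJP closure for backward.
            x_jax = jax.numpy.asarray(np.asarray(x_torch.detach()))
            y_jax, vjp_fn = jax.vjp(jax_function, x_jax)
            ctx.vjp_fn = vjp_fn
            return torch.from_numpy(np.array(y_jax))

        @staticmethod
        def backward(ctx, grad_output):
            # Pull the upstream gradient into JAX, apply the stored VJP, and
            # hand the result back to Torch.
            grad_jax = jax.numpy.asarray(np.asarray(grad_output.detach()))
            (dx_jax,) = ctx.vjp_fn(grad_jax)
            return torch.from_numpy(np.array(dx_jax))

    return _Wrapped.apply

Under this sketch, wrap_as_torch_function_sketch(log_sum_exp) returns a Torch callable whose output matches log_sum_exp and whose backward() reproduces jax.grad(log_sum_exp), which is the behaviour test_wrap_as_torch_function_single_arg asserts of the real wrapper.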