|
2 | 2 | import numpy as np |
3 | 3 | import pytest |
4 | 4 | import torch |
| 5 | +from jax.tree_util import tree_all, tree_map |
5 | 6 |
|
6 | 7 | from s2fft.utils import torch_wrapper |
7 | 8 |
|
| 9 | +jax.config.update("jax_enable_x64", True) |
| 10 | + |
8 | 11 |
|
9 | 12 | def sum_abs_square(x: jax.Array) -> float: |
10 | 13 | return (abs(x) ** 2).sum() |
11 | 14 |
|
12 | 15 |
|
13 | | -@pytest.mark.parametrize("jax_function", [sum_abs_square]) |
14 | | -@pytest.mark.parametrize("input_shape", [(), (1,), (2,), (3, 4)]) |
15 | | -def test_wrap_as_torch_function_single_arg(rng, input_shape, jax_function): |
16 | | - x_jax = jax.numpy.asarray(rng.standard_normal(input_shape)) |
17 | | - y_jax = jax_function(x_jax) |
| 16 | +def log_sum_exp(x: jax.Array) -> float: |
| 17 | + xp = x.__array_namespace__() |
| 18 | + max_x = x.max() |
| 19 | + return max_x + xp.log(xp.exp(x - max_x).sum()) |
| 20 | + |
| 21 | + |
| 22 | +DTYPES = ["float32", "float64"] |
| 23 | + |
| 24 | +INPUT_SHAPES = [(), (1,), (2,), (3, 4)] |
| 25 | + |
| 26 | +PYTREE_STRUCTURES = [(), [(), ((1,), (2, 3))], {"a": [(1,), ()], "b": {"0": (1, 2)}}] |
| 27 | + |
| 28 | +JAX_SINGLE_ARG_FUNCTIONS = [sum_abs_square, log_sum_exp] |
| 29 | + |
| 30 | + |
| 31 | +def generate_pytree(rng, converter, dtype, structure): |
| 32 | + if isinstance(structure, tuple): |
| 33 | + if structure == () or all(isinstance(child, int) for child in structure): |
| 34 | + return converter(rng.standard_normal(structure, dtype=dtype)) |
| 35 | + else: |
| 36 | + return tuple( |
| 37 | + generate_pytree(rng, converter, dtype, child) for child in structure |
| 38 | + ) |
| 39 | + elif isinstance(structure, list): |
| 40 | + return [generate_pytree(rng, converter, dtype, child) for child in structure] |
| 41 | + elif isinstance(structure, dict): |
| 42 | + return { |
| 43 | + key: generate_pytree(rng, converter, dtype, value) |
| 44 | + for key, value in structure.items() |
| 45 | + } |
| 46 | + else: |
| 47 | + raise TypeError( |
| 48 | + f"pytree structure with type {type(structure)} not of recognised type" |
| 49 | + ) |
| 50 | + |
| 51 | + |
| 52 | +@pytest.mark.parametrize("input_shape", INPUT_SHAPES) |
| 53 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 54 | +def test_jax_array_to_torch_tensor(rng, input_shape, dtype): |
| 55 | + x_jax = jax.numpy.asarray(rng.standard_normal(input_shape, dtype=dtype)) |
18 | 56 | x_torch = torch_wrapper.jax_array_to_torch_tensor(x_jax) |
19 | 57 | assert isinstance(x_torch, torch.Tensor) |
| 58 | + assert x_torch.dtype == getattr(torch, dtype) |
| 59 | + np.testing.assert_allclose(np.asarray(x_jax), np.asarray(x_torch)) |
| 60 | + |
| 61 | + |
| 62 | +@pytest.mark.parametrize("input_shape", INPUT_SHAPES) |
| 63 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 64 | +def test_torch_tensor_to_jax_array(rng, input_shape, dtype): |
| 65 | + x_torch = torch.from_numpy(rng.standard_normal(input_shape, dtype=dtype)) |
| 66 | + x_jax = torch_wrapper.torch_tensor_to_jax_array(x_torch) |
| 67 | + assert isinstance(x_jax, jax.Array) |
| 68 | + assert x_jax.dtype == dtype |
| 69 | + np.testing.assert_allclose(np.asarray(x_jax), np.asarray(x_torch)) |
| 70 | + |
| 71 | + |
| 72 | +@pytest.mark.parametrize("pytree_structure", PYTREE_STRUCTURES) |
| 73 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 74 | +def test_tree_map_jax_array_to_torch_tensor(rng, pytree_structure, dtype): |
| 75 | + jax_pytree = generate_pytree(rng, jax.numpy.asarray, dtype, pytree_structure) |
| 76 | + torch_pytree = torch_wrapper.tree_map_jax_array_to_torch_tensor(jax_pytree) |
| 77 | + assert tree_all( |
| 78 | + tree_map(lambda leaf: isinstance(leaf, jax.Array), jax_pytree), |
| 79 | + ) |
| 80 | + assert tree_all( |
| 81 | + tree_map(lambda leaf: leaf.dtype == dtype, jax_pytree), |
| 82 | + ) |
| 83 | + assert tree_all( |
| 84 | + tree_map(lambda leaf: isinstance(leaf, torch.Tensor), torch_pytree), |
| 85 | + ) |
| 86 | + assert tree_all( |
| 87 | + tree_map(lambda leaf: leaf.dtype == getattr(torch, dtype), torch_pytree), |
| 88 | + ) |
| 89 | + assert tree_all( |
| 90 | + tree_map( |
| 91 | + lambda leaf_1, leaf_2: np.allclose(np.asarray(leaf_1), np.asarray(leaf_2)), |
| 92 | + torch_pytree, |
| 93 | + jax_pytree, |
| 94 | + ) |
| 95 | + ) |
| 96 | + |
| 97 | + |
| 98 | +@pytest.mark.parametrize("pytree_structure", PYTREE_STRUCTURES) |
| 99 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 100 | +def test_tree_map_torch_tensor_to_jax_array(rng, pytree_structure, dtype): |
| 101 | + torch_pytree = generate_pytree(rng, torch.from_numpy, dtype, pytree_structure) |
| 102 | + jax_pytree = torch_wrapper.tree_map_torch_tensor_to_jax_array(torch_pytree) |
| 103 | + assert tree_all( |
| 104 | + tree_map(lambda leaf: isinstance(leaf, jax.Array), jax_pytree), |
| 105 | + ) |
| 106 | + assert tree_all( |
| 107 | + tree_map(lambda leaf: leaf.dtype == dtype, jax_pytree), |
| 108 | + ) |
| 109 | + assert tree_all( |
| 110 | + tree_map(lambda leaf: isinstance(leaf, torch.Tensor), torch_pytree), |
| 111 | + ) |
| 112 | + assert tree_all( |
| 113 | + tree_map(lambda leaf: leaf.dtype == getattr(torch, dtype), torch_pytree), |
| 114 | + ) |
| 115 | + assert tree_all( |
| 116 | + tree_map( |
| 117 | + lambda leaf_1, leaf_2: np.allclose(np.asarray(leaf_1), np.asarray(leaf_2)), |
| 118 | + torch_pytree, |
| 119 | + jax_pytree, |
| 120 | + ) |
| 121 | + ) |
| 122 | + |
| 123 | + |
| 124 | +@pytest.mark.parametrize("jax_function", JAX_SINGLE_ARG_FUNCTIONS) |
| 125 | +@pytest.mark.parametrize("input_shape", INPUT_SHAPES) |
| 126 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 127 | +def test_wrap_as_torch_function_single_arg(rng, input_shape, dtype, jax_function): |
| 128 | + x_numpy = rng.standard_normal(input_shape, dtype=dtype) |
| 129 | + x_jax = jax.numpy.asarray(x_numpy) |
| 130 | + y_jax = jax_function(x_jax) |
| 131 | + x_torch = torch.tensor(x_numpy, requires_grad=True) |
20 | 132 | torch_function = torch_wrapper.wrap_as_torch_function(jax_function) |
21 | 133 | y_torch = torch_function(x_torch) |
22 | 134 | assert isinstance(y_torch, torch.Tensor) |
23 | | - np.testing.assert_allclose(np.asarray(y_jax), np.asarray(y_torch)) |
| 135 | + assert y_torch.dtype == getattr(torch, dtype) |
| 136 | + np.testing.assert_allclose(np.asarray(y_jax), np.asarray(y_torch.detach())) |
| 137 | + dy_dx_jax = jax.grad(jax_function)(x_jax) |
| 138 | + y_torch.backward() |
| 139 | + assert x_torch.grad.dtype == getattr(torch, dtype) |
| 140 | + np.testing.assert_allclose(np.asarray(dy_dx_jax), np.asarray(x_torch.grad)) |
0 commit comments