
Commit 883a220

Conchylicultor authored and The dataclass_array Authors committed

Add pytorch 2.0 compatibility

PiperOrigin-RevId: 517483590

1 parent 4cfa827 commit 883a220

File tree

6 files changed (+55, -22 lines)


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+* Add `torch==2.0.0` support
+
 ## [1.4.0] - 2023-03-13
 
 * **Add `torch` support!**
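
For context, a minimal sketch of what the new `torch` support enables: converting a `DataclassArray` between numpy and torch through `as_xnp`, the method patched below. The `Ray` class is a hypothetical example modeled on the project README, not part of this commit, and passing the `torch` module directly to `as_xnp` is an assumption about the public API.

import dataclasses

import dataclass_array as dca
from dataclass_array.typing import FloatArray
import numpy as np
import torch  # torch>=2.0 now supported

@dataclasses.dataclass(frozen=True)
class Ray(dca.DataclassArray):
  pos: FloatArray['*shape 3']
  dir: FloatArray['*shape 3']

ray = Ray(pos=np.zeros((4, 3)), dir=np.ones((4, 3)))
assert ray.shape == (4,)

ray_torch = ray.as_xnp(torch)  # All array fields converted to torch.Tensor
assert isinstance(ray_torch.pos, torch.Tensor)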

dataclass_array/array_dataclass.py

Lines changed: 12 additions & 4 deletions
@@ -515,10 +515,18 @@ def as_xnp(self: _DcT, xnp: enp.NpModule) -> _DcT:
       return self
     # Direct `torch` <> `tf`/`jax` conversion not supported, so convert to
     # `numpy`
-    if enp.lazy.has_torch and (
-        xnp is enp.lazy.torch or self.xnp is enp.lazy.torch
-    ):
-      array_fn = lambda f: xnp.asarray(np.asarray(f.value))
+    if enp.lazy.is_torch_xnp(xnp) or enp.lazy.is_torch_xnp(self.xnp):
+
+      def _as_torch(f):
+        arr = np.asarray(f.value)
+        # Torch fails for scalar arrays:
+        # https://github.com/pytorch/pytorch/issues/97021
+        if enp.lazy.is_torch_xnp(xnp) and not arr.shape:  # Destination is torch
+          return xnp.asarray(arr.item(), dtype=lazy.as_torch_dtype(arr.dtype))
+
+        return xnp.asarray(arr)
+
+      array_fn = _as_torch
     else:
       array_fn = lambda f: xnp.asarray(f.value)
 
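
The scalar special-case above works around `torch.asarray` failing on 0-d numpy arrays (pytorch#97021): unwrap to a Python scalar and restore the dtype explicitly. A standalone sketch of the same pattern, where `_NP_TO_TORCH` is a hypothetical stand-in for `enp.lazy.as_torch_dtype`:

import numpy as np
import torch

# Hypothetical stand-in for `enp.lazy.as_torch_dtype` (partial mapping only).
_NP_TO_TORCH = {
    np.dtype('float32'): torch.float32,
    np.dtype('float64'): torch.float64,
    np.dtype('int32'): torch.int32,
    np.dtype('int64'): torch.int64,
}

def to_torch(arr: np.ndarray) -> torch.Tensor:
  if not arr.shape:
    # 0-d arrays can trip `torch.asarray` (pytorch#97021), so unwrap to a
    # Python scalar and restore the dtype explicitly.
    return torch.asarray(arr.item(), dtype=_NP_TO_TORCH[arr.dtype])
  return torch.asarray(arr)

assert to_torch(np.array(1.5, dtype=np.float32)).dtype == torch.float32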

dataclass_array/array_dataclass_test.py

Lines changed: 1 addition & 3 deletions
@@ -609,11 +609,9 @@ def test_torch_device():
     ]
 )
 def test_vmap(xnp: enp.NpModule):
-  import functorch
-
   vmap_fn = {
       enp.lazy.jnp: enp.lazy.jax.vmap,
-      enp.lazy.torch: functorch.vmap,
+      enp.lazy.torch: enp.lazy.torch.func.vmap,
   }[xnp]
 
   batch_shape = 3
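
The test now uses `torch.func.vmap`, the torch 2.0 home of the former `functorch.vmap`. A minimal sketch of the replacement call:

import torch

def scale(x: torch.Tensor) -> torch.Tensor:
  return x * 2

# torch>=2.0 ships vmap in core as `torch.func.vmap`; before 2.0 the
# equivalent lived in the separate `functorch` package.
batched = torch.func.vmap(scale)
out = batched(torch.ones(3, 4))
assert out.shape == (3, 4)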

dataclass_array/testing.py

Lines changed: 22 additions & 1 deletion
@@ -21,6 +21,7 @@
 
 from dataclass_array import array_dataclass
 from dataclass_array.typing import FloatArray  # pylint: disable=g-multiple-import
+from etils import enp
 from etils.etree import jax as etree
 from etils.etree import Tree
 import numpy as np
@@ -61,6 +62,9 @@ def assert_allclose(
 def assert_array_equal(
     x,
     y,
+    *,
+    atol: Optional[float] = None,
+    rtol: Optional[float] = None,
 ) -> None:
   """Assert the 2 objects are equal.
 
@@ -71,9 +75,26 @@ def assert_array_equal(
   Args:
     x: First element to compare
     y: Second element to compare
+    atol: Absolute tolerance
+    rtol: Relative tolerance
   """
   assert type(x) == type(y)  # pylint: disable=unidiomatic-typecheck
   assert x.shape == y.shape
-  assert_allclose(x, y)
+  assert_allclose(x, y, atol=atol, rtol=rtol)
   if isinstance(x, array_dataclass.DataclassArray):
     assert x.xnp is y.xnp
+
+
+def skip_vmap_unavailable(xnp: enp.NpModule, *, skip_torch: str = '') -> None:
+  """Skip the test when vmap is not available."""
+  skip = False
+  if enp.lazy.is_tf_xnp(xnp):
+    # TODO(b/152678472): TF does not support vmap & tf.nest
+    skip = True
+  elif enp.lazy.is_torch_xnp(xnp):
+    if skip_torch:
+      skip = True
+  if skip:
+    import pytest  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
+
+    pytest.skip('Vectorization not supported yet with TF / Torch')
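
A sketch of how the two new helpers compose in a parametrized test; `test_example` is hypothetical, but the helper calls mirror the signatures added above:

import dataclass_array as dca
from etils import enp
import numpy as np
import pytest

@pytest.mark.parametrize('xnp', [np, enp.lazy.jnp])
def test_example(xnp: enp.NpModule):
  # No-op for numpy/jax; skips under TF (and torch, when a reason is passed).
  dca.testing.skip_vmap_unavailable(xnp)
  x = xnp.asarray([1.0, 2.0])
  dca.testing.assert_array_equal(x, x + 1e-9, atol=1e-6)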

dataclass_array/vectorization.py

Lines changed: 11 additions & 5 deletions
@@ -377,12 +377,18 @@ def _jax_vmap_cached(fn: _FnT, *, in_axes) -> _FnT:
 @functools.lru_cache(maxsize=None)
 def _torch_vmap_cached(fn: _FnT, *, in_axes) -> _FnT:
   """Like `jax.vmap` but cache the function."""
-  try:
-    import functorch  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
-  except ImportError as e:
-    epy.reraise(e, suffix='. vectorization with `pytorch` requires functorch')
+  if hasattr(enp.lazy.torch, 'func'):  # torch 2.0
+    vmap = enp.lazy.torch.func.vmap
+  else:
+    try:
+      import functorch  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
+    except ImportError as e:
+      epy.reraise(
+          e, suffix='. vectorization with `pytorch<2` requires functorch'
+      )
+    vmap = functorch.vmap
 
-  return functorch.vmap(
+  return vmap(
       fn,
       in_dims=in_axes,
   )
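
For reference, a minimal sketch of the resolved `vmap` on the torch 2.0 path, including the `in_dims` plumbing used by `_torch_vmap_cached`:

import torch

def dot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
  return (a * b).sum()

# Maps over the leading batch dimension of both arguments, mirroring the
# `in_dims=in_axes` forwarding in `_torch_vmap_cached`.
batched_dot = torch.func.vmap(dot, in_dims=(0, 0))
out = batched_dot(torch.ones(5, 3), torch.ones(5, 3))
assert out.shape == (5,)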

dataclass_array/vectorization_test.py

Lines changed: 7 additions & 9 deletions
@@ -26,7 +26,6 @@
 from etils import enp
 import jax
 import pytest
-import tensorflow.experimental.numpy as tnp
 
 H = 2
 W = 3
@@ -177,16 +176,15 @@ def fn(self):
   assert a.shape == (3,)
   assert a.x == 5
 
-  # Vectorization supported
-  if xnp not in [
-      tnp,
-  ]:
-    a = a.fn()
-    assert a.xnp is xnp
+  # Tree-map supported
+  a = jax.tree_util.tree_map(lambda x: x, a)
   assert a.shape == (3,)
   assert a.x == 5
 
-  # Tree-map supported
-  a = jax.tree_util.tree_map(lambda x: x, a)
+  # Vectorization supported
+  dca.testing.skip_vmap_unavailable(xnp)
+
+  a = a.fn()
+  assert a.xnp is xnp
   assert a.shape == (3,)
   assert a.x == 5
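
The reordered test exercises the pytree path before vectorization. A sketch of `tree_map` over a `DataclassArray` (the `Point` class is a hypothetical example; the pytree registration itself is what the test relies on):

import dataclasses

import dataclass_array as dca
from dataclass_array.typing import FloatArray
import jax
import numpy as np

@dataclasses.dataclass(frozen=True)
class Point(dca.DataclassArray):  # Hypothetical example class
  x: FloatArray['*shape 3']

p = Point(x=np.ones((3, 3)))

# `DataclassArray` registers as a jax pytree, so `tree_map` maps over fields.
p2 = jax.tree_util.tree_map(lambda v: v * 2, p)
assert p2.shape == (3,)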
