
Commit 12c3785

Conchylicultor authored and The dataclass_array Authors committed
Add pytorch support to dataclass_array
PiperOrigin-RevId: 509333699
1 parent 6a2c54e commit 12c3785

File tree

9 files changed: +140 -30 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+* Add `torch` support (experimental). Require to call
+  `dca.activate_torch_support()`
+
 ## [1.3.0] - 2023-01-16
 
 * Added: Support for static `dca.DataclassArray` (dataclasses with only
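
A minimal usage sketch of the opt-in torch backend described by this entry. The `Point` class and its field annotation are hypothetical examples (the annotation style mirrors `dca.typing.f32['*s']` from `import_test.py` below), and `dca.activate_torch_support()` is the re-export added in `__init__.py`:

import dataclasses

import dataclass_array as dca
import torch

# Torch support is experimental and must be enabled explicitly.
dca.activate_torch_support()


@dataclasses.dataclass(frozen=True)
class Point(dca.DataclassArray):  # hypothetical example class
  x: dca.typing.f32['*s']


# Constructing from a torch tensor now works and infers the torch backend.
p = Point(x=torch.zeros((3,)))
assert p.xnp is torch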

dataclass_array/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,9 @@
 from dataclass_array.ops import stack
 from dataclass_array.vectorization import vectorize_method
 
+# TODO(epot): Remove once Torch has better numpy API
+from etils.enp import activate_torch_support
+
 # `dca.testing` do not depend on pytest or other heavy deps, so is safe to
 # import
 from dataclass_array import testing

dataclass_array/array_dataclass.py

Lines changed: 42 additions & 6 deletions
@@ -195,8 +195,9 @@ def __init_subclass__(cls, **kwargs):
     super().__init_subclass__(**kwargs)
     # TODO(epot): Could have smart __repr__ which display types if array have
     # too many values (maybe directly in `edc.field(repr=...)`).
-    edc.dataclass(kw_only=True, repr=True)(cls)
-    cls._dca_tree_map_registered = False
+    edc.dataclass(kw_only=True, repr=True, auto_cast=False)(cls)
+    cls._dca_jax_tree_registered = False
+    cls._dca_torch_tree_registered = False
     # Typing annotations have to be lazily evaluated (to support
     # `from __future__ import annotations` and forward reference)
     # To avoid costly `typing.get_type_hints` which perform `eval` and `str`
@@ -217,10 +218,20 @@ def __post_init__(self) -> None:
     _init_cls(self)
 
     # Register the tree_map here instead of `__init_subclass__` as `jax` may
-    # not have been registered yet during import
-    if enp.lazy.has_jax and not cls._dca_tree_map_registered:  # pylint: disable=protected-access
+    # not have been imported yet during import.
+    if enp.lazy.has_jax and not cls._dca_jax_tree_registered:  # pylint: disable=protected-access
       enp.lazy.jax.tree_util.register_pytree_node_class(cls)
-      cls._dca_tree_map_registered = True  # pylint: disable=protected-access
+      cls._dca_jax_tree_registered = True  # pylint: disable=protected-access
+
+    if enp.lazy.has_torch and not cls._dca_torch_tree_registered:  # pylint: disable=protected-access
+      # Note: Torch is updating it's tree API to make it public and use `optree`
+      # as backend: https://github.com/pytorch/pytorch/issues/65761
+      enp.lazy.torch.utils._pytree._register_pytree_node(  # pylint: disable=protected-access
+          cls,
+          flatten_fn=lambda a: a.tree_flatten(),
+          unflatten_fn=lambda vals, ctx: cls.tree_unflatten(ctx, vals),
+      )
+      cls._dca_torch_tree_registered = True  # pylint: disable=protected-access
 
     # Validate and normalize array fields
     # * Maybe cast (list, np) -> xnp
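
The private `_register_pytree_node` call above is what lets torch's pytree utilities traverse a `DataclassArray`, mirroring the existing jax registration. A small sketch of the effect, using a hypothetical `Point` class (same shape as the one in `import_test.py` below); note `torch.utils._pytree` is a private torch module, as the comment in the hunk points out:

import dataclasses

import dataclass_array as dca
import numpy as np
import torch.utils._pytree as pytree


@dataclasses.dataclass(frozen=True)
class Point(dca.DataclassArray):  # hypothetical example class
  x: dca.typing.f32['*s']


p = Point(x=np.zeros((3,)))
# tree_map now recurses into the dataclass' array fields, exactly like
# `jax.tree_map` does thanks to the jax registration.
p2 = pytree.tree_map(lambda v: v[None, ...], p)
assert p2.shape == (1, 3)
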
@@ -442,14 +453,28 @@ def as_tf(self: _DcT) -> _DcT:
     """Returns the instance as containing `tf.Tensor`."""
     return self.as_xnp(enp.lazy.tnp)
 
+  def as_torch(self: _DcT) -> _DcT:
+    """Returns the instance as containing `torch.Tensor`."""
+    return self.as_xnp(enp.lazy.torch)
+
   def as_xnp(self: _DcT, xnp: enp.NpModule) -> _DcT:
     """Returns the instance as containing `xnp.ndarray`."""
     if xnp is self.xnp:  # No-op
       return self
+    # Direct `torch` <> `tf`/`jax` conversion not supported, so convert to
+    # `numpy`
+    if (
+        enp.lazy.has_torch
+        and xnp is enp.lazy.torch
+        or self.xnp is enp.lazy.torch
+    ):
+      array_fn = lambda f: xnp.asarray(np.asarray(f.value))
+    else:
+      array_fn = lambda f: xnp.asarray(f.value)
 
     # Update all childs
     new_self = self._map_field(
-        array_fn=lambda f: xnp.asarray(f.value),
+        array_fn=array_fn,
         dc_fn=lambda f: f.value.as_xnp(xnp),
     )
     return new_self
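
Because direct torch <-> tf/jax conversion is not supported (per the comment in the hunk above), `as_xnp` routes values through an intermediate `np.asarray` whenever either side is torch, and `as_torch()` is simply `as_xnp(enp.lazy.torch)`. A short sketch with a hypothetical `Point` class:

import dataclasses

import dataclass_array as dca
import numpy as np
from etils import enp

dca.activate_torch_support()


@dataclasses.dataclass(frozen=True)
class Point(dca.DataclassArray):  # hypothetical example class
  x: dca.typing.f32['*s']


p = Point(x=np.zeros((3,)))
p_torch = p.as_torch()                 # alias for p.as_xnp(enp.lazy.torch)
assert p_torch.xnp is enp.lazy.torch
# torch -> jax bounces through numpy internally (np.asarray, then xnp.asarray).
p_jax = p_torch.as_jax()
assert p_jax.xnp is enp.lazy.jnp
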
@@ -518,6 +543,17 @@ def _get_xnp(f: _ArrayField) -> enp.NpModule:
       return None
     xnp = _infer_xnp(xnps)
 
+    if (
+        enp.lazy.has_torch
+        and xnp is enp.lazy.torch
+        and not hasattr(enp.lazy.torch, '__etils_np_mode__')
+    ):
+      raise ValueError(
+          'torch support currently require to call:\n'
+          'import dataclass_array as dca\n'
+          'dca.activate_torch_support()'
+      )
+
     def _cast_field(f: _ArrayField) -> None:
       try:
         new_value = np_utils.asarray(

dataclass_array/array_dataclass_test.py

Lines changed: 45 additions & 16 deletions
@@ -28,7 +28,7 @@
 import tensorflow as tf
 
 # Activate the fixture
-set_tnp = enp.testing.set_tnp
+enable_torch_tf_np_mode = enp.testing.enable_torch_tf_np_mode
 
 # TODO(epot): Test dtype `complex`, `str`
 
@@ -70,8 +70,8 @@ def assert_val(p: Point, shape: Shape, xnp: enp.NpModule = None):
     _assert_common(p, shape=shape, xnp=xnp)
     assert p.x.shape == shape
     assert p.y.shape == shape
-    assert p.x.dtype == np.float32
-    assert p.y.dtype == np.float32
+    assert enp.lazy.as_dtype(p.x.dtype) == np.float32
+    assert enp.lazy.as_dtype(p.y.dtype) == np.float32
     assert isinstance(p.x, xnp.ndarray)
     assert isinstance(p.y, xnp.ndarray)
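
The switch from a plain dtype comparison to `enp.lazy.as_dtype(...)` is needed because backend dtypes such as `torch.float32` are distinct objects that do not reliably compare equal to numpy dtypes; `as_dtype` normalizes them to a numpy dtype first. A tiny illustrative sketch:

import numpy as np
import torch
from etils import enp

t = torch.zeros((3,), dtype=torch.float32)
# Normalize the backend dtype to a numpy dtype before comparing, as the
# updated assertions above do.
assert enp.lazy.as_dtype(t.dtype) == np.float32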

@@ -98,8 +98,8 @@ def assert_val(p: Isometrie, shape: Shape, xnp: enp.NpModule = None):
     _assert_common(p, shape=shape, xnp=xnp)
     assert p.r.shape == shape + (3, 3)
     assert p.t.shape == shape + (2,)
-    assert p.r.dtype == np.float32
-    assert p.t.dtype == np.int32
+    assert enp.lazy.as_dtype(p.r.dtype) == np.float32
+    assert enp.lazy.as_dtype(p.t.dtype) == np.int32
     assert isinstance(p.r, xnp.ndarray)
     assert isinstance(p.t, xnp.ndarray)
 
@@ -226,8 +226,8 @@ def assert_val(p: WithStatic, shape: Shape, xnp: enp.NpModule = None):
     NestedOnlyStatic.assert_val(p.nested_static, shape, xnp=xnp)
     assert p.x.shape == shape + (3,)
     assert p.y.shape == shape + (2, 2)
-    assert p.x.dtype == np.float32
-    assert p.y.dtype == np.float32
+    assert enp.lazy.as_dtype(p.x.dtype) == np.float32
+    assert enp.lazy.as_dtype(p.y.dtype) == np.float32
     assert isinstance(p.x, xnp.ndarray)
     assert isinstance(p.y, xnp.ndarray)
     # Static field is correctly forwarded
@@ -546,12 +546,15 @@ def test_convert(
 ):
   p = dca_cls.make(xnp=xnp, shape=(2,))
   assert p.xnp is xnp
+
   assert p.as_np().xnp is enp.lazy.np
   assert p.as_jax().xnp is enp.lazy.jnp
   assert p.as_tf().xnp is enp.lazy.tnp
+  assert p.as_torch().xnp is enp.lazy.torch
   assert p.as_xnp(np).xnp is enp.lazy.np
   assert p.as_xnp(enp.lazy.jnp).xnp is enp.lazy.jnp
   assert p.as_xnp(enp.lazy.tnp).xnp is enp.lazy.tnp
+  assert p.as_xnp(enp.lazy.torch).xnp is enp.lazy.torch
   # Make sure the nested class are also updated
   dca_cls.assert_val(p.as_jax(), (2,), xnp=enp.lazy.jnp)
 
@@ -587,24 +590,44 @@ def test_infer_np(xnp: enp.NpModule):
 
 
 @parametrize_dataclass_arrays
-def test_jax_tree_map(dca_cls: DcaTest):
+@pytest.mark.parametrize(
+    'tree_map',
+    [
+        enp.lazy.jax.tree_map,
+        enp.lazy.torch.utils._pytree.tree_map,
+    ],
+)
+def test_torch_tree_map(tree_map, dca_cls: DcaTest):
   p = dca_cls.make(shape=(3,), xnp=np)
-  p = enp.lazy.jax.tree_map(lambda x: x[None, ...], p)
+  p = tree_map(lambda x: x[None, ...], p)
   dca_cls.assert_val(p, (1, 3), xnp=np)
 
 
-def test_jax_vmap():
+@enp.testing.parametrize_xnp(
+    restrict=[
+        'jnp',
+        'torch',
+    ]
+)
+def test_vmap(xnp: enp.NpModule):
+  import functorch
+
+  vmap_fn = {
+      enp.lazy.jnp: enp.lazy.jax.vmap,
+      enp.lazy.torch: functorch.vmap,
+  }[xnp]
+
   batch_shape = 3
 
-  @enp.lazy.jax.vmap
+  @vmap_fn
   def fn(p: WithStatic) -> WithStatic:
     assert isinstance(p, WithStatic)
     assert p.shape == ()  # pylint:disable=g-explicit-bool-comparison
     return p.replace(x=p.x + 1)
 
-  x = WithStatic.make((batch_shape,), xnp=enp.lazy.jnp)
+  x = WithStatic.make((batch_shape,), xnp=xnp)
   y = fn(x)
-  WithStatic.assert_val(y, (batch_shape,), xnp=enp.lazy.jnp)
+  WithStatic.assert_val(y, (batch_shape,), xnp=xnp)
   # pos was updated
   np.testing.assert_allclose(y.x, np.ones((batch_shape, 3)))
   np.testing.assert_allclose(y.y, np.zeros((batch_shape, 2, 2)))
@@ -628,8 +651,8 @@ class PointNoCast(dca.DataclassArray):
       y=xnp.array([1, 2, 3], dtype=np.uint8),
   )
   assert p.shape == (3,)
-  assert p.x.dtype == np.float16
-  assert p.y.dtype == np.uint8
+  assert enp.lazy.as_dtype(p.x.dtype) == np.float16
+  assert enp.lazy.as_dtype(p.y.dtype) == np.uint8
 
 
 @enp.testing.parametrize_xnp()
@@ -689,7 +712,13 @@ class PointDynamicShape(dca.DataclassArray):
   assert dca.stack([p, p]).shape == (2,) + batch_shape
 
   # Incompatible shape will raise an error
-  with pytest.raises((ValueError, tf.errors.InvalidArgumentError)):
+  expected_exception_cls = {
+      enp.lazy.np: ValueError,
+      enp.lazy.jnp: ValueError,
+      enp.lazy.tnp: tf.errors.InvalidArgumentError,
+      enp.lazy.torch: RuntimeError,
+  }
+  with pytest.raises(expected_exception_cls[xnp]):
     dca.stack([p, p2])
 
   if batch_shape:

dataclass_array/import_test.py

Lines changed: 10 additions & 2 deletions
@@ -20,12 +20,20 @@
 
 from __future__ import annotations
 
+import dataclasses
 import sys
 
 import dataclass_array as dca
+from etils import enp
+import pytest
 
-del dca
+
+@dataclasses.dataclass(frozen=True)
+class A(dca.DataclassArray):
+  x: dca.typing.f32['*s']
 
 
 def test_lazy():
-  pass
+
+  x = A(x=[1.0, 2.0])
+  assert x.xnp is enp.lazy.np

dataclass_array/utils/np_utils_test.py

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@
 import numpy as np
 import pytest
 
+# Activate the fixture
+enable_torch_tf_np_mode = enp.testing.enable_torch_tf_np_mode
+
 
 @enp.testing.parametrize_xnp()
 def test_get_xnp(xnp: enp.NpModule):

dataclass_array/vectorization.py

Lines changed: 29 additions & 4 deletions
@@ -311,9 +311,19 @@ def _vmap_method(
   if xnp is enp.lazy.np:
     return _vmap_method_np(args, map_non_static=map_non_static)
   elif xnp is enp.lazy.jnp:
-    return _vmap_method_jnp(args, map_non_static=map_non_static)
+    return _vmap_method_jax_torch(
+        args,
+        map_non_static=map_non_static,
+        make_vmap_fn=_jax_vmap_cached,
+    )
   elif xnp is enp.lazy.tnp:
     return _vmap_method_tf(args, map_non_static=map_non_static)
+  elif xnp is enp.lazy.torch:
+    return _vmap_method_jax_torch(
+        args,
+        map_non_static=map_non_static,
+        make_vmap_fn=_torch_vmap_cached,
+    )
   raise TypeError(f'Invalid numpy module: {xnp}')
 
 
@@ -334,10 +344,11 @@ def _vmap_method_np(
   return tree_utils.tree_map(_stack, *outs)
 
 
-def _vmap_method_jnp(
+def _vmap_method_jax_torch(
     args: inspect_utils.BoundArgs[Any, _OutT],
     *,
     map_non_static: _MapNonStatic,
+    make_vmap_fn: Any,
 ) -> _OutT:
   """vectorization using `jax` backend."""
 
@@ -349,21 +360,35 @@ def _vmap_method_jnp(
   in_axes = tuple(arg.value for arg in in_axes_args)
 
   # Vectorize self and args
-  vfn = _vmap_cached(args.fn, in_axes=in_axes)
+  vfn = make_vmap_fn(args.fn, in_axes=in_axes)
 
   # Call `vfn(self, *args, **kwargs)`
   return args.call(vfn)
 
 
 @functools.lru_cache(maxsize=None)
-def _vmap_cached(fn: _FnT, *, in_axes) -> _FnT:
+def _jax_vmap_cached(fn: _FnT, *, in_axes) -> _FnT:
   """Like `jax.vmap` but cache the function."""
   return enp.lazy.jax.vmap(
       fn,
       in_axes=in_axes,
   )
 
 
+@functools.lru_cache(maxsize=None)
+def _torch_vmap_cached(fn: _FnT, *, in_axes) -> _FnT:
+  """Like `jax.vmap` but cache the function."""
+  try:
+    import functorch  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error
+  except ImportError as e:
+    epy.reraise(e, suffix='. vectorization with `pytorch` require functorch')
+
+  return functorch.vmap(
+      fn,
+      in_dims=in_axes,
+  )
+
+
 def _vmap_method_tf(
     args: inspect_utils.BoundArgs[Any, _OutT],
     *,
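
jax and torch now share one vectorization path; the only backend-specific pieces are the vmap factory and the axis keyword (`in_axes` for `jax.vmap`, `in_dims` for `functorch.vmap`), which is why `_torch_vmap_cached` translates the keyword. A minimal sketch of the two calls side by side (functorch's API later moved into `torch.func` in newer PyTorch releases):

import functorch
import jax
import jax.numpy as jnp
import torch


def add_one(x):
  return x + 1


# Same vectorization, different keyword name for the mapped axis.
jax_vfn = jax.vmap(add_one, in_axes=0)
torch_vfn = functorch.vmap(add_one, in_dims=0)

print(jax_vfn(jnp.zeros((3,))))      # -> [1. 1. 1.]
print(torch_vfn(torch.zeros((3,))))  # -> tensor([1., 1., 1.])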

dataclass_array/vectorization_test.py

Lines changed: 4 additions & 2 deletions
@@ -34,7 +34,7 @@
 X1 = 5
 
 # Activate the fixture
-set_tnp = enp.testing.set_tnp
+enable_torch_tf_np_mode = enp.testing.enable_torch_tf_np_mode
 
 
 @pytest.mark.parametrize(
@@ -179,7 +179,9 @@ def fn(self):
   assert a.x == 5
 
   # Vectorization supported
-  if xnp != tnp:
+  if xnp not in [
+      tnp,
+  ]:
     a = a.fn()
     assert a.xnp is xnp
     assert a.shape == (3,)

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ dev = [
     "chex",
     "jax[cpu]",
     "tf-nightly",
+    "torch",
 ]
 
 [tool.pyink]
