Skip to content

Commit 4c54d15

Browse files
Conchylicultor (The dataclass_array Authors)
authored and committed
Add methods for torch device
PiperOrigin-RevId: 515367102
1 parent fe8a32d commit 4c54d15

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2424
## [Unreleased]
2525

2626
* **Add `torch` support!**
27+
* Add `.cpu()`, `.cuda()`, `.to()` methods to move the dataclass between
  devices when using torch.
2729

2830
## [1.3.0] - 2023-01-16
2931

dataclass_array/array_dataclass.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
import typing_extensions
3838
from typing_extensions import Annotated, Literal, TypeAlias # pylint: disable=g-multiple-import
3939

40+
if typing.TYPE_CHECKING:
41+
import torch # pytype: disable=import-error
42+
4043
lazy = enp.lazy
4144

4245
# TODO(pytype): Should use `dca.typing.DcT` but bound does not work across
@@ -477,8 +480,6 @@ def as_xnp(self: _DcT, xnp: enp.NpModule) -> _DcT:
477480
)
478481
return new_self
479482

480-
# ====== Internal ======
481-
482483
# TODO(pytype): Remove hack. Currently, Python does not support typing
483484
# annotations for modules, by pytype auto-infer the correct type.
484485
# So this hack allow auto-completion
@@ -497,6 +498,33 @@ def xnp(self) -> enp.NpModule:
497498
"""Returns the numpy module of the class (np, jnp, tnp)."""
498499
return self._xnp
499500

501+
# ====== Torch specific methods ======
502+
# Could also add
503+
# * x.detach
504+
# * x.is_cuda
505+
# * x.device
506+
# * x.get_device
507+
508+
def to(self: _DcT, device, **kwargs) -> _DcT:
  """Move the dataclass array to the given device.

  Only valid when the underlying arrays are `torch` tensors; forwards
  `device` and any extra kwargs to each field's `Tensor.to`.
  """
  if lazy.is_torch_xnp(self.xnp):
    return self.map_field(lambda arr: arr.to(device, **kwargs))
  raise ValueError('`.to` can only be called when `xnp == torch`')
513+
514+
def cpu(self: _DcT, *args, **kwargs) -> _DcT:
  """Move the dataclass array to the CPU device.

  Only valid when the underlying arrays are `torch` tensors; forwards
  all arguments to each field's `Tensor.cpu`.
  """
  if lazy.is_torch_xnp(self.xnp):
    return self.map_field(lambda arr: arr.cpu(*args, **kwargs))
  raise ValueError('`.cpu` can only be called when `xnp == torch`')
519+
520+
def cuda(self: _DcT, *args, **kwargs) -> _DcT:
  """Move the dataclass array to the CUDA device.

  Only valid when the underlying arrays are `torch` tensors; forwards
  all arguments to each field's `Tensor.cuda`.
  """
  if lazy.is_torch_xnp(self.xnp):
    return self.map_field(lambda arr: arr.cuda(*args, **kwargs))
  raise ValueError('`.cuda` can only be called when `xnp == torch`')
525+
526+
# ====== Internal ======
527+
500528
@epy.cached_property
501529
def _all_array_fields(self) -> dict[str, _ArrayField]:
502530
"""All array fields, including `None` values."""

dataclass_array/array_dataclass_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,12 +597,19 @@ def test_infer_np(xnp: enp.NpModule):
597597
enp.lazy.torch.utils._pytree.tree_map,
598598
],
599599
)
600-
def test_torch_tree_map(tree_map, dca_cls: DcaTest):
600+
def test_tree_map(tree_map, dca_cls: DcaTest):
  """Tree-mapping over a dataclass array applies the fn to every leaf."""
  obj = dca_cls.make(shape=(3,), xnp=np)
  # Expanding each leaf with a new leading axis should expand the
  # dataclass shape the same way: (3,) -> (1, 3).
  expanded = tree_map(lambda arr: arr[None, ...], obj)
  dca_cls.assert_val(expanded, (1, 3), xnp=np)
604604

605605

606+
def test_torch_device():
  """`.cpu()` / `.to('cpu')` round-trip keeps shape and xnp unchanged."""
  torch_xnp = enp.lazy.torch
  dc = Nested.make(shape=(2,), xnp=torch_xnp)
  dc = dc.cpu()
  dc = dc.to('cpu')
  Nested.assert_val(dc, (2,), xnp=torch_xnp)
611+
612+
606613
@enp.testing.parametrize_xnp(
607614
restrict=[
608615
'jnp',

0 commit comments

Comments
 (0)