3737import typing_extensions
3838from typing_extensions import Annotated , Literal , TypeAlias # pylint: disable=g-multiple-import
3939
40+ if typing .TYPE_CHECKING :
41+ import torch # pytype: disable=import-error
42+
4043lazy = enp .lazy
4144
4245# TODO(pytype): Should use `dca.typing.DcT` but bound does not work across
@@ -477,8 +480,6 @@ def as_xnp(self: _DcT, xnp: enp.NpModule) -> _DcT:
477480 )
478481 return new_self
479482
480- # ====== Internal ======
481-
482483 # TODO(pytype): Remove hack. Currently, Python does not support typing
483484 # annotations for modules, by pytype auto-infer the correct type.
484485 # So this hack allow auto-completion
@@ -497,6 +498,33 @@ def xnp(self) -> enp.NpModule:
497498 """Returns the numpy module of the class (np, jnp, tnp)."""
498499 return self ._xnp
499500
501+ # ====== Torch specific methods ======
502+ # Could also add
503+ # * x.detach
504+ # * x.is_cuda
505+ # * x.device
506+ # * x.get_device
507+
508+ def to (self : _DcT , device , ** kwargs ) -> _DcT :
509+ """Move the dataclass array to the device."""
510+ if not lazy .is_torch_xnp (self .xnp ):
511+ raise ValueError ('`.to` can only be called when `xnp == torch`' )
512+ return self .map_field (lambda f : f .to (device , ** kwargs ))
513+
514+ def cpu (self : _DcT , * args , ** kwargs ) -> _DcT :
515+ """Move the dataclass array to the CPU device."""
516+ if not lazy .is_torch_xnp (self .xnp ):
517+ raise ValueError ('`.cpu` can only be called when `xnp == torch`' )
518+ return self .map_field (lambda f : f .cpu (* args , ** kwargs ))
519+
520+ def cuda (self : _DcT , * args , ** kwargs ) -> _DcT :
521+ """Move the dataclass array to the CUDA device."""
522+ if not lazy .is_torch_xnp (self .xnp ):
523+ raise ValueError ('`.cuda` can only be called when `xnp == torch`' )
524+ return self .map_field (lambda f : f .cuda (* args , ** kwargs ))
525+
526+ # ====== Internal ======
527+
500528 @epy .cached_property
501529 def _all_array_fields (self ) -> dict [str , _ArrayField ]:
502530 """All array fields, including `None` values."""
0 commit comments