File tree Expand file tree Collapse file tree 7 files changed +4
-20
lines changed
Expand file tree Collapse file tree 7 files changed +4
-20
lines changed Original file line number Diff line number Diff line change @@ -23,8 +23,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323
2424## [ Unreleased]
2525
26- * Add ` torch ` support (experimental). Require to call
27- ` dca.activate_torch_support() `
26+ * ** Add ` torch ` support!**
2827
2928## [ 1.3.0] - 2023-01-16
3029
Original file line number Diff line number Diff line change 3737from dataclass_array .ops import stack
3838from dataclass_array .vectorization import vectorize_method
3939
40- # TODO(epot): Remove once Torch has better numpy API
41- from etils .enp import activate_torch_support
42-
4340# `dca.testing` do not depend on pytest or other heavy deps, so is safe to
4441# import
4542from dataclass_array import testing
Original file line number Diff line number Diff line change @@ -541,17 +541,6 @@ def _get_xnp(f: _ArrayField) -> enp.NpModule:
541541 return None
542542 xnp = _infer_xnp (xnps )
543543
544- if (
545- enp .lazy .has_torch
546- and xnp is enp .lazy .torch
547- and not hasattr (enp .lazy .torch , '__etils_np_mode__' )
548- ):
549- raise ValueError (
550- 'torch support currently require to call:\n '
551- 'import dataclass_array as dca\n '
552- 'dca.activate_torch_support()'
553- )
554-
555544 def _cast_field (f : _ArrayField ) -> None :
556545 try :
557546 new_value = np_utils .asarray (
Original file line number Diff line number Diff line change 2828import tensorflow as tf
2929
3030# Activate the fixture
31- enable_torch_tf_np_mode = enp .testing .enable_torch_tf_np_mode
31+ enable_tf_np_mode = enp .testing .set_tnp
3232
3333# TODO(epot): Test dtype `complex`, `str`
3434
Original file line number Diff line number Diff line change 2525
2626import dataclass_array as dca
2727from etils import enp
28- import pytest
2928
3029
3130@dataclasses .dataclass (frozen = True )
Original file line number Diff line number Diff line change 2323import pytest
2424
2525# Activate the fixture
26- enable_torch_tf_np_mode = enp .testing .enable_torch_tf_np_mode
26+ enable_tf_np_mode = enp .testing .set_tnp
2727
2828
2929@enp .testing .parametrize_xnp ()
Original file line number Diff line number Diff line change 3434X1 = 5
3535
3636# Activate the fixture
37- enable_torch_tf_np_mode = enp .testing .enable_torch_tf_np_mode
37+ enable_tf_np_mode = enp .testing .set_tnp
3838
3939
4040@pytest .mark .parametrize (
You can’t perform that action at this time.
0 commit comments