Skip to content

Commit fe8a32d

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Delete mock_torch
PiperOrigin-RevId: 515343503
1 parent 845f715 commit fe8a32d

File tree

7 files changed

+4
-20
lines changed

7 files changed

+4
-20
lines changed

CHANGELOG.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26-
* Add `torch` support (experimental). Require to call
27-
`dca.activate_torch_support()`
26+
* **Add `torch` support!**
2827

2928
## [1.3.0] - 2023-01-16
3029

dataclass_array/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@
3737
from dataclass_array.ops import stack
3838
from dataclass_array.vectorization import vectorize_method
3939

40-
# TODO(epot): Remove once Torch has better numpy API
41-
from etils.enp import activate_torch_support
42-
4340
# `dca.testing` do not depend on pytest or other heavy deps, so is safe to
4441
# import
4542
from dataclass_array import testing

dataclass_array/array_dataclass.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -541,17 +541,6 @@ def _get_xnp(f: _ArrayField) -> enp.NpModule:
541541
return None
542542
xnp = _infer_xnp(xnps)
543543

544-
if (
545-
enp.lazy.has_torch
546-
and xnp is enp.lazy.torch
547-
and not hasattr(enp.lazy.torch, '__etils_np_mode__')
548-
):
549-
raise ValueError(
550-
'torch support currently require to call:\n'
551-
'import dataclass_array as dca\n'
552-
'dca.activate_torch_support()'
553-
)
554-
555544
def _cast_field(f: _ArrayField) -> None:
556545
try:
557546
new_value = np_utils.asarray(

dataclass_array/array_dataclass_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import tensorflow as tf
2929

3030
# Activate the fixture
31-
enable_torch_tf_np_mode = enp.testing.enable_torch_tf_np_mode
31+
enable_tf_np_mode = enp.testing.set_tnp
3232

3333
# TODO(epot): Test dtype `complex`, `str`
3434

dataclass_array/import_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import dataclass_array as dca
2727
from etils import enp
28-
import pytest
2928

3029

3130
@dataclasses.dataclass(frozen=True)

dataclass_array/utils/np_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import pytest
2424

2525
# Activate the fixture
26-
enable_torch_tf_np_mode = enp.testing.enable_torch_tf_np_mode
26+
enable_tf_np_mode = enp.testing.set_tnp
2727

2828

2929
@enp.testing.parametrize_xnp()

dataclass_array/vectorization_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
X1 = 5
3535

3636
# Activate the fixture
37-
enable_torch_tf_np_mode = enp.testing.enable_torch_tf_np_mode
37+
enable_tf_np_mode = enp.testing.set_tnp
3838

3939

4040
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)