Skip to content

Commit af3ceae

Browse files
committed
Avoid tensorflow in CI tests (too large dep)
1 parent 571f1cf commit af3ceae

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

requirements-tests.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ pytest
44
psutil
55
msgpack
66
torch
7-
tensorflow
7+
#tensorflow # too large dependency; torch should be enough for testing

tests/test_tensor.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414
import blosc2
1515
import numpy as np
1616

17-
try:
18-
import tensorflow as tf
19-
import torch
20-
except ImportError:
21-
pytest.skip("skipping torch / tensorflow tests", allow_module_level=True)
22-
2317

2418
##### pack / unpack #####
2519

@@ -73,11 +67,13 @@ def test_pack_array2_struct(size, dtype):
7367
@pytest.mark.parametrize(
7468
"size, dtype",
7569
[
76-
(1e6, torch.float32),
77-
(1e6, torch.float64),
78-
(1e6, torch.int8),
70+
(1e6, "float32"),
71+
(1e6, "float64"),
72+
(1e6, "int8"),
7973
])
8074
def test_pack_tensor_torch(size, dtype):
75+
torch = pytest.importorskip("torch")
76+
dtype = getattr(torch, dtype)
8177
tensor = torch.arange(size, dtype=dtype)
8278
cframe = blosc2.pack_tensor(tensor)
8379
atensor = np.asarray(tensor)
@@ -95,8 +91,9 @@ def test_pack_tensor_torch(size, dtype):
9591
(1e6, np.int8),
9692
])
9793
def test_pack_tensor_tensorflow(size, dtype):
94+
tensorflow = pytest.importorskip("tensorflow")
9895
array = np.arange(size, dtype=dtype)
99-
tensor = tf.constant(array)
96+
tensor = tensorflow.constant(array)
10097
cframe = blosc2.pack_tensor(tensor)
10198
atensor = np.asarray(tensor)
10299
assert len(cframe) < atensor.size * atensor.dtype.itemsize
@@ -164,8 +161,9 @@ def test_save_tensor_array(size, dtype, urlpath):
164161
(1e6, "float32", "test.bl2"),
165162
])
166163
def test_save_tensor_tensorflow(size, dtype, urlpath):
164+
tensorflow = pytest.importorskip("tensorflow")
167165
nparray = np.arange(size, dtype=dtype)
168-
tensor = tf.constant(nparray)
166+
tensor = tensorflow.constant(nparray)
169167
serial_size = blosc2.save_tensor(tensor, urlpath, mode="w")
170168
assert serial_size < nparray.size * nparray.itemsize
171169

@@ -181,6 +179,7 @@ def test_save_tensor_tensorflow(size, dtype, urlpath):
181179
(1e6, "float32", "test.bl2"),
182180
])
183181
def test_save_tensor_torch(size, dtype, urlpath):
182+
torch = pytest.importorskip("torch")
184183
nparray = np.arange(size, dtype=dtype)
185184
tensor = torch.tensor(nparray)
186185
serial_size = blosc2.save_tensor(tensor, urlpath, mode="w")

0 commit comments

Comments
 (0)