
Commit 6ae9c28

Deprecate jax and tensorflow, remove unused files
1 parent c0aeb4a commit 6ae9c28

File tree

8 files changed: +42 -157 lines

outlines/models/transformers.py
outlines/processors/base_logits_processor.py
outlines/processors/tensor_adapters/__init__.py
outlines/processors/tensor_adapters/jax.py (deleted)
outlines/processors/tensor_adapters/tensorflow.py (deleted)
tests/models/test_transformers.py
tests/processors/test_base_processor.py
tests/processors/test_tensor_adapters.py


outlines/models/transformers.py

Lines changed: 18 additions & 0 deletions
@@ -244,11 +244,29 @@ def __init__(
             and isinstance(model, FlaxPreTrainedModel)
         ):
             self.tensor_library_name = "jax"
+            warnings.warn("""
+                Support for `jax` has been deprecated and will be removed in
+                version 1.4.0 of Outlines. Please use `torch` instead.
+                Transformers models using `jax` do not support structured
+                generation.
+                """,
+                DeprecationWarning,
+                stacklevel=2,
+            )
         elif (
             TFPreTrainedModel is not None
             and isinstance(model, TFPreTrainedModel)
         ):
             self.tensor_library_name = "tensorflow"
+            warnings.warn("""
+                Support for `tensorflow` has been deprecated and will be removed in
+                version 1.4.0 of Outlines. Please use `torch` instead.
+                Transformers models using `tensorflow` do not support structured
+                generation.
+                """,
+                DeprecationWarning,
+                stacklevel=2,
+            )
         else:
             self.tensor_library_name = "torch"
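For reference, a minimal sketch of what callers see after this change. The model id "gpt2" is a placeholder for illustration, not something taken from this commit:

import warnings

import outlines
import transformers

# Loading a Flax (jax-backed) model still works, but now emits a
# DeprecationWarning and remains unsupported for structured generation.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = outlines.from_transformers(
        transformers.FlaxAutoModelForCausalLM.from_pretrained("gpt2"),  # placeholder model id
        transformers.AutoTokenizer.from_pretrained("gpt2"),
    )

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert model.tensor_library_name == "jax"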

outlines/processors/base_logits_processor.py

Lines changed: 3 additions & 3 deletions
@@ -28,8 +28,8 @@ def __init__(self, tensor_library_name: str):
         ----------
         tensor_library_name
             The name of the library to use to manipulate tensors. Possible
-            values are "jax", "mlx", "numpy", "tensorflow" and "torch". You
-            must choose the library that your model is using.
+            values are "mlx", "numpy" and "torch". You must choose the library
+            that your model is using.
         """
         # Temporary fix as torch raises a warning that can cause can an error
         # with python 3.12.
@@ -52,7 +52,7 @@ def reset(self):
         needs to be reset for a new generation.

         """
-        pass
+        pass  # pragma: no cover

     @abstractmethod
     def process_logits(
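A short sketch of the constructor contract described by the updated docstring. The base class name (OutlinesLogitsProcessor), its import path, and the process_logits argument names are assumptions based on current Outlines, not part of this diff:

from outlines.processors.base_logits_processor import OutlinesLogitsProcessor  # assumed import path

class PassthroughLogitsProcessor(OutlinesLogitsProcessor):
    def process_logits(self, input_ids, logits):
        # A real processor would mask disallowed tokens; here logits pass through unchanged.
        return logits

# Valid tensor library names are now "mlx", "numpy" and "torch".
processor = PassthroughLogitsProcessor("torch")

# With the jax adapter removed from the registry, "jax" is treated like any
# other unknown name, which the tests below show raises NotImplementedError.
try:
    PassthroughLogitsProcessor("jax")
except NotImplementedError:
    pass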

outlines/processors/tensor_adapters/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -2,25 +2,19 @@

 from typing import Union

-from .jax import JAXTensorAdapter
 from .mlx import MLXTensorAdapter
 from .numpy import NumpyTensorAdapter
-from .tensorflow import TensorFlowTensorAdapter
 from .torch import TorchTensorAdapter


 tensor_adapters = {
-    "jax": JAXTensorAdapter,
     "mlx": MLXTensorAdapter,
     "numpy": NumpyTensorAdapter,
-    "tensorflow": TensorFlowTensorAdapter,
     "torch": TorchTensorAdapter,
 }

 TensorAdapterImplementation = Union[
-    JAXTensorAdapter,
     MLXTensorAdapter,
     NumpyTensorAdapter,
-    TensorFlowTensorAdapter,
     TorchTensorAdapter,
 ]
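The adapter registry shrinks accordingly; a quick illustrative check against the remaining entries (not part of the commit):

from outlines.processors.tensor_adapters import tensor_adapters

# Only the mlx, numpy and torch adapters are registered after this commit.
assert sorted(tensor_adapters) == ["mlx", "numpy", "torch"]

# Look up and instantiate the torch adapter, the recommended backend going forward.
torch_adapter = tensor_adapters["torch"]()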

outlines/processors/tensor_adapters/jax.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

outlines/processors/tensor_adapters/tensorflow.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

tests/models/test_transformers.py

Lines changed: 18 additions & 16 deletions
@@ -40,25 +40,27 @@ def test_transformers_instantiate_simple():


 def test_transformers_instantiate_flax_model():
-    model = outlines.from_transformers(
-        transformers.FlaxAutoModelForCausalLM.from_pretrained(TEST_MODEL),
-        transformers.AutoTokenizer.from_pretrained(TEST_MODEL),
-    )
-    assert isinstance(model, Transformers)
-    assert isinstance(model.tokenizer, TransformerTokenizer)
-    assert isinstance(model.type_adapter, TransformersTypeAdapter)
-    assert model.tensor_library_name == "jax"
+    with pytest.warns(DeprecationWarning):
+        model = outlines.from_transformers(
+            transformers.FlaxAutoModelForCausalLM.from_pretrained(TEST_MODEL),
+            transformers.AutoTokenizer.from_pretrained(TEST_MODEL),
+        )
+        assert isinstance(model, Transformers)
+        assert isinstance(model.tokenizer, TransformerTokenizer)
+        assert isinstance(model.type_adapter, TransformersTypeAdapter)
+        assert model.tensor_library_name == "jax"


 def test_transformers_instantiate_tensorflow_model():
-    model = outlines.from_transformers(
-        transformers.TFAutoModelForCausalLM.from_pretrained(TEST_MODEL),
-        transformers.AutoTokenizer.from_pretrained(TEST_MODEL),
-    )
-    assert isinstance(model, Transformers)
-    assert isinstance(model.tokenizer, TransformerTokenizer)
-    assert isinstance(model.type_adapter, TransformersTypeAdapter)
-    assert model.tensor_library_name == "tensorflow"
+    with pytest.warns(DeprecationWarning):
+        model = outlines.from_transformers(
+            transformers.TFAutoModelForCausalLM.from_pretrained(TEST_MODEL),
+            transformers.AutoTokenizer.from_pretrained(TEST_MODEL),
+        )
+        assert isinstance(model, Transformers)
+        assert isinstance(model.tokenizer, TransformerTokenizer)
+        assert isinstance(model.type_adapter, TransformersTypeAdapter)
+        assert model.tensor_library_name == "tensorflow"


 def test_transformers_instantiate_mamba():

tests/processors/test_base_processor.py

Lines changed: 2 additions & 10 deletions
@@ -1,6 +1,5 @@
 from typing import List

-import jax.numpy as jnp
 import numpy as np
 import pytest
 import torch
@@ -14,7 +13,7 @@
     HAS_MLX = False


-libraries = ["numpy", "torch", "jax"]
+libraries = ["numpy", "torch"]
 if HAS_MLX:
     libraries.append("mlx")

@@ -43,14 +42,6 @@
         (torch.tensor([1, 2], dtype=torch.float32), torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), AssertionError),
         (torch.tensor([[[1, 2]]], dtype=torch.float32), torch.tensor([[[1, 2]]], dtype=torch.int32), ValueError),
     ],
-    "jax": [
-        (jnp.array([1, 2], dtype=jnp.float32), jnp.array([1, 2], dtype=jnp.int32), None),
-        (jnp.array([[1, 2], [3, 4]], dtype=jnp.float32), jnp.array([[1, 2], [3, 4]], dtype=jnp.int32), None),
-        (jnp.array([1, 2], dtype=jnp.float32), jnp.array([[1, 2]], dtype=jnp.int32), None),
-        (jnp.array([[1, 2]], dtype=jnp.float32), jnp.array([1, 2], dtype=jnp.int32), AssertionError),
-        (jnp.array([1, 2], dtype=jnp.float32), jnp.array([[1, 2], [3, 4]], dtype=jnp.int32), AssertionError),
-        (jnp.array([[[1, 2]]], dtype=jnp.float32), jnp.array([[[1, 2]]], dtype=jnp.int32), ValueError),
-    ],
 }
 if HAS_MLX:
     arrays["mlx"] = [
@@ -76,6 +67,7 @@ def test_base_logits_processor_init(library):
     assert processor.tensor_adapter is not None
     with pytest.raises(NotImplementedError):
         processor = MockLogitsProcessor("foo")
+    processor.reset()


 @pytest.mark.parametrize("library", libraries)

tests/processors/test_tensor_adapters.py

Lines changed: 1 addition & 22 deletions
@@ -1,18 +1,13 @@
 import pytest
 from pytest import mark

-import jax
-import jax.numpy as jnp
 import numpy as np
-import tensorflow as tf
 import torch

 from outlines.processors.tensor_adapters import (
     NumpyTensorAdapter,
     TorchTensorAdapter,
     MLXTensorAdapter,
-    JAXTensorAdapter,
-    TensorFlowTensorAdapter,
 )

 try:
@@ -25,26 +20,19 @@


 adapters = {
-    "jax": JAXTensorAdapter(),
     "numpy": NumpyTensorAdapter(),
-    "tensorflow": TensorFlowTensorAdapter(),
     "torch": TorchTensorAdapter(),
 }
 if HAS_MLX:
     adapters["mlx"] = MLXTensorAdapter()

-frameworks = ["jax", "numpy", "tensorflow", "torch", "mlx"]
+frameworks = ["numpy", "torch", "mlx"]

 def create_tensor(framework, shape, dtype=None):
     if framework == "torch":
         return torch.randn(*shape)
     elif framework == "numpy":
         return np.random.randn(*shape)
-    elif framework == "jax":
-        key = jax.random.PRNGKey(0)
-        return jax.random.normal(key, shape=shape)
-    elif framework == "tensorflow":
-        return tf.random.normal(shape)
     elif framework == "mlx":
         if not HAS_MLX:
             pytest.skip("MLX not available")
@@ -55,10 +43,6 @@ def compare_tensors(framework, tensor1, tensor2):
         return torch.allclose(tensor1, tensor2)
     elif framework == "numpy":
         return np.array_equal(tensor1, tensor2)
-    elif framework == "jax":
-        return jax.numpy.array_equal(tensor1, tensor2)
-    elif framework == "tensorflow":
-        return tf.reduce_all(tf.equal(tensor1, tensor2))
     elif framework == "mlx":
         if not HAS_MLX:
             pytest.skip("MLX not available")
@@ -243,11 +227,6 @@ def test_tensor_adapter_apply_mask(framework):
         mask = torch.randn(2, 3) > 0
     elif framework == "numpy":
         mask = np.random.randn(2, 3) > 0
-    elif framework == "jax":
-        key = jax.random.PRNGKey(0)
-        mask = jax.random.normal(key, shape=(2, 3)) > 0
-    elif framework == "tensorflow":
-        mask = tf.random.normal((2, 3)) > 0
     elif framework == "mlx":
         if not HAS_MLX:
             pytest.skip("MLX not available")
