Skip to content

Commit 6d4258c

Browse files
committed
Update transformers to stable 5.x
1 parent 3c874c6 commit 6d4258c

File tree

3 files changed: +2255 additions, −2200 deletions

mlx_audio/tts/tests/test_models.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -616,14 +616,14 @@ def _default_config(self):
616616
"layer_types": ["full_attention"] * 28,
617617
}
618618

619-
@patch("transformers.LlamaTokenizer")
619+
@patch("transformers.AutoTokenizer")
620620
def test_init(self, mock_tokenizer):
621621
"""Test LlamaModel initialization."""
622622
from mlx_audio.tts.models.llama.llama import Model, ModelConfig
623623

624624
# Mock the tokenizer instance
625625
mock_tokenizer_instance = MagicMock()
626-
mock_tokenizer.return_value = mock_tokenizer_instance
626+
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
627627

628628
# Create a minimal config
629629
config = ModelConfig(**self._default_config)
@@ -634,14 +634,22 @@ def test_init(self, mock_tokenizer):
634634
# Check that model was created
635635
self.assertIsInstance(model, Model)
636636

637-
@patch("transformers.LlamaTokenizer")
637+
@patch("transformers.AutoTokenizer")
638638
def test_generate(self, mock_tokenizer):
639639
"""Test generate method."""
640640
from mlx_audio.tts.models.llama.llama import Model, ModelConfig
641641

642642
# Mock tokenizer instance
643643
mock_tokenizer_instance = MagicMock()
644-
mock_tokenizer.return_value = mock_tokenizer_instance
644+
645+
def mock_tokenize(text, return_tensors=None):
646+
result = MagicMock()
647+
result.input_ids = mx.array([[1, 2, 3, 4]], dtype=mx.int64)
648+
return result
649+
650+
mock_tokenizer_instance.side_effect = mock_tokenize
651+
mock_tokenizer_instance.__call__ = mock_tokenize
652+
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
645653

646654
config = ModelConfig(**self._default_config)
647655
model = Model(config)
@@ -651,7 +659,7 @@ def test_generate(self, mock_tokenizer):
651659
self.assertEqual(input_ids.shape[0], 2)
652660

653661
logits = model(input_ids)
654-
self.assertEqual(logits.shape, (2, 9, config.vocab_size))
662+
self.assertEqual(logits.shape, (2, input_ids.shape[1], config.vocab_size))
655663

656664
# Verify batched input creation with reference audio
657665
input_ids, input_mask = model.prepare_input_ids(
@@ -660,16 +668,16 @@ def test_generate(self, mock_tokenizer):
660668
self.assertEqual(input_ids.shape[0], 2)
661669

662670
logits = model(input_ids)
663-
self.assertEqual(logits.shape, (2, 22, config.vocab_size))
671+
self.assertEqual(logits.shape, (2, input_ids.shape[1], config.vocab_size))
664672

665-
@patch("transformers.LlamaTokenizer")
673+
@patch("transformers.AutoTokenizer")
666674
def test_sanitize(self, mock_tokenizer):
667675
"""Test sanitize method."""
668676
from mlx_audio.tts.models.llama.llama import Model, ModelConfig
669677

670678
# Mock tokenizer instance
671679
mock_tokenizer_instance = MagicMock()
672-
mock_tokenizer.return_value = mock_tokenizer_instance
680+
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
673681

674682
# Create a config with tie_word_embeddings=True
675683
config = ModelConfig(
@@ -918,14 +926,14 @@ def _default_config(self):
918926
"vocab_size": 134400,
919927
}
920928

921-
@patch("transformers.LlamaTokenizer")
929+
@patch("transformers.AutoTokenizer")
922930
def test_init(self, mock_tokenizer):
923931
"""Test initialization."""
924932
from mlx_audio.tts.models.outetts.outetts import Model, ModelConfig
925933

926934
# Mock the tokenizer instance
927935
mock_tokenizer_instance = MagicMock()
928-
mock_tokenizer.return_value = mock_tokenizer_instance
936+
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
929937

930938
# Create a minimal config
931939
config = ModelConfig(**self._default_config)
@@ -936,14 +944,14 @@ def test_init(self, mock_tokenizer):
936944
# Check that model was created
937945
self.assertIsInstance(model, Model)
938946

939-
@patch("transformers.LlamaTokenizer")
947+
@patch("transformers.AutoTokenizer")
940948
def test_generate(self, mock_tokenizer):
941949
"""Test generate method."""
942950
from mlx_audio.tts.models.outetts.outetts import Model, ModelConfig
943951

944952
# Mock tokenizer instance
945953
mock_tokenizer_instance = MagicMock()
946-
mock_tokenizer.return_value = mock_tokenizer_instance
954+
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
947955

948956
config = ModelConfig(**self._default_config)
949957
model = Model(config)

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ requires-python = ">=3.10"
2121
dependencies = [
2222
"mlx>=0.25.2",
2323
"numpy>=1.26.4",
24-
"huggingface_hub>=0.27.0",
25-
"transformers==5.0.0rc3",
26-
"mlx-lm==0.30.5",
24+
"huggingface_hub>=1.0",
25+
"transformers>=5.0.0",
26+
"mlx-lm==0.31.1",
2727
"tqdm>=4.67.1",
2828
"sounddevice==0.5.3",
2929
"miniaudio>=1.61",

0 commit comments

Comments (0)