Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions mlx_audio/tts/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,14 +616,14 @@ def _default_config(self):
"layer_types": ["full_attention"] * 28,
}

@patch("transformers.LlamaTokenizer")
@patch("transformers.AutoTokenizer")
def test_init(self, mock_tokenizer):
"""Test LlamaModel initialization."""
from mlx_audio.tts.models.llama.llama import Model, ModelConfig

# Mock the tokenizer instance
mock_tokenizer_instance = MagicMock()
mock_tokenizer.return_value = mock_tokenizer_instance
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance

# Create a minimal config
config = ModelConfig(**self._default_config)
Expand All @@ -634,14 +634,22 @@ def test_init(self, mock_tokenizer):
# Check that model was created
self.assertIsInstance(model, Model)

@patch("transformers.LlamaTokenizer")
@patch("transformers.AutoTokenizer")
def test_generate(self, mock_tokenizer):
"""Test generate method."""
from mlx_audio.tts.models.llama.llama import Model, ModelConfig

# Mock tokenizer instance
mock_tokenizer_instance = MagicMock()
mock_tokenizer.return_value = mock_tokenizer_instance

def mock_tokenize(text, return_tensors=None):
result = MagicMock()
result.input_ids = mx.array([[1, 2, 3, 4]], dtype=mx.int64)
return result

mock_tokenizer_instance.side_effect = mock_tokenize
mock_tokenizer_instance.__call__ = mock_tokenize
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance

config = ModelConfig(**self._default_config)
model = Model(config)
Expand All @@ -651,7 +659,7 @@ def test_generate(self, mock_tokenizer):
self.assertEqual(input_ids.shape[0], 2)

logits = model(input_ids)
self.assertEqual(logits.shape, (2, 9, config.vocab_size))
self.assertEqual(logits.shape, (2, input_ids.shape[1], config.vocab_size))

# Verify batched input creation with reference audio
input_ids, input_mask = model.prepare_input_ids(
Expand All @@ -660,16 +668,16 @@ def test_generate(self, mock_tokenizer):
self.assertEqual(input_ids.shape[0], 2)

logits = model(input_ids)
self.assertEqual(logits.shape, (2, 22, config.vocab_size))
self.assertEqual(logits.shape, (2, input_ids.shape[1], config.vocab_size))

@patch("transformers.LlamaTokenizer")
@patch("transformers.AutoTokenizer")
def test_sanitize(self, mock_tokenizer):
"""Test sanitize method."""
from mlx_audio.tts.models.llama.llama import Model, ModelConfig

# Mock tokenizer instance
mock_tokenizer_instance = MagicMock()
mock_tokenizer.return_value = mock_tokenizer_instance
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance

# Create a config with tie_word_embeddings=True
config = ModelConfig(
Expand Down Expand Up @@ -918,14 +926,14 @@ def _default_config(self):
"vocab_size": 134400,
}

@patch("transformers.LlamaTokenizer")
@patch("transformers.AutoTokenizer")
def test_init(self, mock_tokenizer):
"""Test initialization."""
from mlx_audio.tts.models.outetts.outetts import Model, ModelConfig

# Mock the tokenizer instance
mock_tokenizer_instance = MagicMock()
mock_tokenizer.return_value = mock_tokenizer_instance
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance

# Create a minimal config
config = ModelConfig(**self._default_config)
Expand All @@ -936,14 +944,14 @@ def test_init(self, mock_tokenizer):
# Check that model was created
self.assertIsInstance(model, Model)

@patch("transformers.LlamaTokenizer")
@patch("transformers.AutoTokenizer")
def test_generate(self, mock_tokenizer):
"""Test generate method."""
from mlx_audio.tts.models.outetts.outetts import Model, ModelConfig

# Mock tokenizer instance
mock_tokenizer_instance = MagicMock()
mock_tokenizer.return_value = mock_tokenizer_instance
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance

config = ModelConfig(**self._default_config)
model = Model(config)
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ requires-python = ">=3.10"
dependencies = [
"mlx>=0.25.2",
"numpy>=1.26.4",
"huggingface_hub>=0.27.0",
"transformers==5.0.0rc3",
"mlx-lm==0.30.5",
"huggingface_hub>=1.0",
"transformers>=5.0.0",
"mlx-lm==0.31.1",
"tqdm>=4.67.1",
"sounddevice==0.5.3",
"miniaudio>=1.61",
Expand Down
Loading