@@ -616,14 +616,14 @@ def _default_config(self):
616616 "layer_types" : ["full_attention" ] * 28 ,
617617 }
618618
619- @patch ("transformers.LlamaTokenizer " )
619+ @patch ("transformers.AutoTokenizer " )
620620 def test_init (self , mock_tokenizer ):
621621 """Test LlamaModel initialization."""
622622 from mlx_audio .tts .models .llama .llama import Model , ModelConfig
623623
624624 # Mock the tokenizer instance
625625 mock_tokenizer_instance = MagicMock ()
626- mock_tokenizer .return_value = mock_tokenizer_instance
626+ mock_tokenizer .from_pretrained . return_value = mock_tokenizer_instance
627627
628628 # Create a minimal config
629629 config = ModelConfig (** self ._default_config )
@@ -634,14 +634,22 @@ def test_init(self, mock_tokenizer):
634634 # Check that model was created
635635 self .assertIsInstance (model , Model )
636636
637- @patch ("transformers.LlamaTokenizer " )
637+ @patch ("transformers.AutoTokenizer " )
638638 def test_generate (self , mock_tokenizer ):
639639 """Test generate method."""
640640 from mlx_audio .tts .models .llama .llama import Model , ModelConfig
641641
642642 # Mock tokenizer instance
643643 mock_tokenizer_instance = MagicMock ()
644- mock_tokenizer .return_value = mock_tokenizer_instance
644+
645+ def mock_tokenize (text , return_tensors = None ):
646+ result = MagicMock ()
647+ result .input_ids = mx .array ([[1 , 2 , 3 , 4 ]], dtype = mx .int64 )
648+ return result
649+
650+ mock_tokenizer_instance .side_effect = mock_tokenize
651+ mock_tokenizer_instance .__call__ = mock_tokenize
652+ mock_tokenizer .from_pretrained .return_value = mock_tokenizer_instance
645653
646654 config = ModelConfig (** self ._default_config )
647655 model = Model (config )
@@ -651,7 +659,7 @@ def test_generate(self, mock_tokenizer):
651659 self .assertEqual (input_ids .shape [0 ], 2 )
652660
653661 logits = model (input_ids )
654- self .assertEqual (logits .shape , (2 , 9 , config .vocab_size ))
662+ self .assertEqual (logits .shape , (2 , input_ids . shape [ 1 ] , config .vocab_size ))
655663
656664 # Verify batched input creation with reference audio
657665 input_ids , input_mask = model .prepare_input_ids (
@@ -660,16 +668,16 @@ def test_generate(self, mock_tokenizer):
660668 self .assertEqual (input_ids .shape [0 ], 2 )
661669
662670 logits = model (input_ids )
663- self .assertEqual (logits .shape , (2 , 22 , config .vocab_size ))
671+ self .assertEqual (logits .shape , (2 , input_ids . shape [ 1 ] , config .vocab_size ))
664672
665- @patch ("transformers.LlamaTokenizer " )
673+ @patch ("transformers.AutoTokenizer " )
666674 def test_sanitize (self , mock_tokenizer ):
667675 """Test sanitize method."""
668676 from mlx_audio .tts .models .llama .llama import Model , ModelConfig
669677
670678 # Mock tokenizer instance
671679 mock_tokenizer_instance = MagicMock ()
672- mock_tokenizer .return_value = mock_tokenizer_instance
680+ mock_tokenizer .from_pretrained . return_value = mock_tokenizer_instance
673681
674682 # Create a config with tie_word_embeddings=True
675683 config = ModelConfig (
@@ -918,14 +926,14 @@ def _default_config(self):
918926 "vocab_size" : 134400 ,
919927 }
920928
921- @patch ("transformers.LlamaTokenizer " )
929+ @patch ("transformers.AutoTokenizer " )
922930 def test_init (self , mock_tokenizer ):
923931 """Test initialization."""
924932 from mlx_audio .tts .models .outetts .outetts import Model , ModelConfig
925933
926934 # Mock the tokenizer instance
927935 mock_tokenizer_instance = MagicMock ()
928- mock_tokenizer .return_value = mock_tokenizer_instance
936+ mock_tokenizer .from_pretrained . return_value = mock_tokenizer_instance
929937
930938 # Create a minimal config
931939 config = ModelConfig (** self ._default_config )
@@ -936,14 +944,14 @@ def test_init(self, mock_tokenizer):
936944 # Check that model was created
937945 self .assertIsInstance (model , Model )
938946
939- @patch ("transformers.LlamaTokenizer " )
947+ @patch ("transformers.AutoTokenizer " )
940948 def test_generate (self , mock_tokenizer ):
941949 """Test generate method."""
942950 from mlx_audio .tts .models .outetts .outetts import Model , ModelConfig
943951
944952 # Mock tokenizer instance
945953 mock_tokenizer_instance = MagicMock ()
946- mock_tokenizer .return_value = mock_tokenizer_instance
954+ mock_tokenizer .from_pretrained . return_value = mock_tokenizer_instance
947955
948956 config = ModelConfig (** self ._default_config )
949957 model = Model (config )
0 commit comments