@@ -26,13 +26,23 @@ def make_rule(self, name: str, build_func: Callable[[Union[NativeWFST, WFST]], N
2626 assert rule .loaded
2727 return rule
2828
29- def decode (self , text : str , kaldi_rules_activity : list [bool ], expected_rule : Optional [KaldiRule ], expected_words_are_dictation_mask : Optional [list [bool ]] = None ):
30- audio_data = self .audio_generator (text )
29+ def decode (self , text_or_audio : Union [str , bytes ], kaldi_rules_activity : list [bool ], expected_rule : Optional [KaldiRule ], expected_words : Optional [list [str ]] = None , expected_words_are_dictation_mask : Optional [list [bool ]] = None ):
30+ if isinstance (text_or_audio , str ):
31+ text = text_or_audio
32+ audio_data = self .audio_generator (text )
33+ if expected_words is None :
34+ expected_words = text .split () if text else []
35+ else :
36+ text = None
37+ audio_data = text_or_audio
38+ if expected_words is None :
39+ expected_words = []
40+
3141 self .decoder .decode (audio_data , True , kaldi_rules_activity )
3242
3343 output , info = self .decoder .get_output ()
3444 assert isinstance (output , str )
35- assert len (output ) > 0 or text == ""
45+ assert len (output ) > 0 or expected_words == []
3646 assert_info_shape (info )
3747
3848 recognized_rule , words , words_are_dictation_mask = self .compiler .parse_output (output )
@@ -42,7 +52,7 @@ def decode(self, text: str, kaldi_rules_activity: list[bool], expected_rule: Opt
4252 assert words_are_dictation_mask == []
4353 else :
4454 assert recognized_rule == expected_rule
45- assert words == text . split ()
55+ assert words == expected_words
4656 if expected_words_are_dictation_mask is None :
4757 expected_words_are_dictation_mask = [False ] * len (words )
4858 assert words_are_dictation_mask == expected_words_are_dictation_mask
@@ -470,16 +480,7 @@ def _build(fst):
470480
471481 random .seed (42 )
472482 audio_data = bytes (random .randint (0 , 255 ) for _ in range (32768 ))
473- self .decoder .decode (audio_data , True , [True ])
474-
475- output , info = self .decoder .get_output ()
476- assert isinstance (output , str )
477- assert_info_shape (info )
478-
479- recognized_rule , words , words_are_dictation_mask = self .compiler .parse_output (output )
480- assert recognized_rule is None
481- assert words == []
482- assert words_are_dictation_mask == []
483+ self .decode (audio_data , [True ], None )
483484
484485 def test_empty_audio (self ):
485486 """Test decoder with empty audio data."""
@@ -488,18 +489,7 @@ def _build(fst):
488489 final_state = fst .add_state (final = True )
489490 fst .add_arc (initial_state , final_state , 'hello' )
490491 rule = self .make_rule ('EmptyAudioRule' , _build )
491-
492- self .decoder .decode (b'' , True , [True ])
493-
494- output , info = self .decoder .get_output ()
495- assert isinstance (output , str )
496- assert output == ""
497- assert_info_shape (info )
498-
499- recognized_rule , words , words_are_dictation_mask = self .compiler .parse_output (output )
500- assert recognized_rule is None
501- assert words == []
502- assert words_are_dictation_mask == []
492+ self .decode (b'' , [True ], None )
503493
504494 def test_very_short_audio (self ):
505495 """Test decoder with very short utterance."""
0 commit comments