Skip to content

Commit df0f329

Browse files
committed
Tests: Refactor grammar tests
1 parent aa1a26c commit df0f329

File tree

1 file changed

+16
-26
lines changed

1 file changed

+16
-26
lines changed

tests/test_grammar.py

Lines changed: 16 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,23 @@ def make_rule(self, name: str, build_func: Callable[[Union[NativeWFST, WFST]], N
2626
assert rule.loaded
2727
return rule
2828

29-
def decode(self, text: str, kaldi_rules_activity: list[bool], expected_rule: Optional[KaldiRule], expected_words_are_dictation_mask: Optional[list[bool]] = None):
30-
audio_data = self.audio_generator(text)
29+
def decode(self, text_or_audio: Union[str, bytes], kaldi_rules_activity: list[bool], expected_rule: Optional[KaldiRule], expected_words: Optional[list[str]] = None, expected_words_are_dictation_mask: Optional[list[bool]] = None):
30+
if isinstance(text_or_audio, str):
31+
text = text_or_audio
32+
audio_data = self.audio_generator(text)
33+
if expected_words is None:
34+
expected_words = text.split() if text else []
35+
else:
36+
text = None
37+
audio_data = text_or_audio
38+
if expected_words is None:
39+
expected_words = []
40+
3141
self.decoder.decode(audio_data, True, kaldi_rules_activity)
3242

3343
output, info = self.decoder.get_output()
3444
assert isinstance(output, str)
35-
assert len(output) > 0 or text == ""
45+
assert len(output) > 0 or expected_words == []
3646
assert_info_shape(info)
3747

3848
recognized_rule, words, words_are_dictation_mask = self.compiler.parse_output(output)
@@ -42,7 +52,7 @@ def decode(self, text: str, kaldi_rules_activity: list[bool], expected_rule: Opt
4252
assert words_are_dictation_mask == []
4353
else:
4454
assert recognized_rule == expected_rule
45-
assert words == text.split()
55+
assert words == expected_words
4656
if expected_words_are_dictation_mask is None:
4757
expected_words_are_dictation_mask = [False] * len(words)
4858
assert words_are_dictation_mask == expected_words_are_dictation_mask
@@ -470,16 +480,7 @@ def _build(fst):
470480

471481
random.seed(42)
472482
audio_data = bytes(random.randint(0, 255) for _ in range(32768))
473-
self.decoder.decode(audio_data, True, [True])
474-
475-
output, info = self.decoder.get_output()
476-
assert isinstance(output, str)
477-
assert_info_shape(info)
478-
479-
recognized_rule, words, words_are_dictation_mask = self.compiler.parse_output(output)
480-
assert recognized_rule is None
481-
assert words == []
482-
assert words_are_dictation_mask == []
483+
self.decode(audio_data, [True], None)
483484

484485
def test_empty_audio(self):
485486
"""Test decoder with empty audio data."""
@@ -488,18 +489,7 @@ def _build(fst):
488489
final_state = fst.add_state(final=True)
489490
fst.add_arc(initial_state, final_state, 'hello')
490491
rule = self.make_rule('EmptyAudioRule', _build)
491-
492-
self.decoder.decode(b'', True, [True])
493-
494-
output, info = self.decoder.get_output()
495-
assert isinstance(output, str)
496-
assert output == ""
497-
assert_info_shape(info)
498-
499-
recognized_rule, words, words_are_dictation_mask = self.compiler.parse_output(output)
500-
assert recognized_rule is None
501-
assert words == []
502-
assert words_are_dictation_mask == []
492+
self.decode(b'', [True], None)
503493

504494
def test_very_short_audio(self):
505495
"""Test decoder with very short utterance."""

0 commit comments

Comments
 (0)