Skip to content

Commit ef69ab8

Browse files
authored
Split long sentences on commas to prevent skipped words (#143)
## Summary When a single sentence exceeds `max_tokens` (50) and has no sentence-ending punctuation (`.`, `!`, `?`), the text splitting logic now falls back to splitting on commas, semicolons, and colons. This prevents the model from silently skipping parts of long sentences. Fixes #38
1 parent 670e777 commit ef69ab8

File tree

3 files changed

+194
-20
lines changed

3 files changed

+194
-20
lines changed

pocket_tts/conditioners/text.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ def __call__(self, text: str) -> TokenizedText:
3535
return TokenizedText(torch.tensor(self.sp.encode(text, out_type=int))[None, :])
3636

3737

38+
DEFAULT_TOKENIZER_N_BINS = 4000
DEFAULT_TOKENIZER_PATH = (
    "hf://kyutai/pocket-tts-without-voice-cloning/"
    "tokenizer.model@d4fdd22ae8c8e1cb3634e150ebeff1dab2d16df3"
)


def get_default_tokenizer() -> SentencePieceTokenizer:
    """Build the stock SentencePieceTokenizer for pocket-tts.

    Uses the pinned default model path and vocabulary size defined above;
    the model file is downloaded from HuggingFace the first time it is used.
    """
    tokenizer = SentencePieceTokenizer(DEFAULT_TOKENIZER_N_BINS, DEFAULT_TOKENIZER_PATH)
    return tokenizer
51+
52+
3853
class LUTConditioner(BaseConditioner):
3954
"""Lookup table TextConditioner.
4055

pocket_tts/models/tts_model.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -838,39 +838,71 @@ def prepare_text_prompt(text: str) -> tuple[str, int]:
838838
return text, frames_after_eos_guess
839839

840840

841+
def _find_boundary_indices(list_of_tokens: list[int], boundary_tokens: list[int]) -> list[int]:
842+
"""Find token indices where text should be split based on boundary tokens.
843+
844+
Returns a list of boundary positions used to slice segments. Each consecutive
845+
pair (indices[i], indices[i+1]) defines one segment. The first element is
846+
always 0 and the last is always len(list_of_tokens).
847+
"""
848+
indices = [0]
849+
previous_was_boundary = False
850+
for idx, token in enumerate(list_of_tokens):
851+
if token in boundary_tokens:
852+
previous_was_boundary = True
853+
else:
854+
if previous_was_boundary:
855+
indices.append(idx)
856+
previous_was_boundary = False
857+
indices.append(len(list_of_tokens))
858+
return indices
859+
860+
861+
def _segments_from_boundaries(
862+
list_of_tokens: list[int], boundary_indices: list[int], tokenizer
863+
) -> list[tuple[int, str]]:
864+
"""Decode token segments between boundary indices into (token_count, text) pairs."""
865+
segments = []
866+
for i in range(len(boundary_indices) - 1):
867+
start = boundary_indices[i]
868+
end = boundary_indices[i + 1]
869+
text = tokenizer.sp.decode(list_of_tokens[start:end])
870+
segments.append((end - start, text))
871+
return segments
872+
873+
841874
def split_into_best_sentences(tokenizer, text_to_generate: str, max_tokens: int) -> list[str]:
842875
text_to_generate, _ = prepare_text_prompt(text_to_generate)
843876
text_to_generate = text_to_generate.strip()
844877
tokens = tokenizer(text_to_generate)
845878
list_of_tokens = tokens.tokens[0].tolist()
846879

847880
_, *end_of_sentence_tokens = tokenizer(".!...?").tokens[0].tolist()
848-
849-
end_of_sentences_indices = [0]
850-
previous_was_end_of_sentence_token = False
851-
852-
for token_idx, token in enumerate(list_of_tokens):
853-
if token in end_of_sentence_tokens:
854-
previous_was_end_of_sentence_token = True
881+
sentence_boundaries = _find_boundary_indices(list_of_tokens, end_of_sentence_tokens)
882+
nb_tokens_and_sentences = _segments_from_boundaries(
883+
list_of_tokens, sentence_boundaries, tokenizer
884+
)
885+
886+
# Sub-split oversized sentences on commas, semicolons, and colons to prevent skipped words
887+
_, *fallback_tokens = tokenizer(",;:").tokens[0].tolist()
888+
refined_segments = []
889+
for nb_tokens, text in nb_tokens_and_sentences:
890+
if nb_tokens <= max_tokens:
891+
refined_segments.append((nb_tokens, text))
855892
else:
856-
if previous_was_end_of_sentence_token:
857-
end_of_sentences_indices.append(token_idx)
858-
previous_was_end_of_sentence_token = False
859-
end_of_sentences_indices.append(len(list_of_tokens))
860-
861-
nb_tokens_and_sentences = []
862-
for i in range(len(end_of_sentences_indices) - 1):
863-
# let's print
864-
start = end_of_sentences_indices[i]
865-
end = end_of_sentences_indices[i + 1]
866-
text = tokenizer.sp.decode(list_of_tokens[start:end])
867-
nb_tokens_and_sentences.append((end - start, text))
893+
sub_tokens = tokenizer(text.strip()).tokens[0].tolist()
894+
sub_boundaries = _find_boundary_indices(sub_tokens, fallback_tokens)
895+
sub_segments = _segments_from_boundaries(sub_tokens, sub_boundaries, tokenizer)
896+
if len(sub_segments) > 1:
897+
refined_segments.extend(sub_segments)
898+
else:
899+
refined_segments.append((nb_tokens, text))
868900

869901
max_nb_tokens_in_a_chunk = max_tokens
870902
chunks = []
871903
current_chunk = ""
872904
current_nb_of_tokens_in_chunk = 0
873-
for nb_tokens, sentence in nb_tokens_and_sentences:
905+
for nb_tokens, sentence in refined_segments:
874906
if current_chunk == "":
875907
current_chunk = sentence
876908
current_nb_of_tokens_in_chunk = nb_tokens
@@ -887,6 +919,16 @@ def split_into_best_sentences(tokenizer, text_to_generate: str, max_tokens: int)
887919
if current_chunk != "":
888920
chunks.append(current_chunk.strip())
889921

922+
for chunk in chunks:
923+
chunk_tokens = tokenizer(chunk.strip()).tokens[0].tolist()
924+
if len(chunk_tokens) > max_tokens:
925+
logger.warning(
926+
"Chunk has %d tokens (max %d), generation may skip words: '%.50s...'",
927+
len(chunk_tokens),
928+
max_tokens,
929+
chunk,
930+
)
931+
890932
return chunks
891933

892934

tests/test_split_sentences.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Tests for the text splitting logic in split_into_best_sentences."""
2+
3+
import pytest
4+
5+
from pocket_tts.conditioners.text import get_default_tokenizer
6+
from pocket_tts.models.tts_model import split_into_best_sentences
7+
8+
9+
@pytest.fixture(scope="session")
def tokenizer():
    """Session-scoped default tokenizer, built once and shared by all tests."""
    return get_default_tokenizer()
12+
13+
14+
def test_short_text_single_chunk(tokenizer):
    """A brief input well under the budget comes back as exactly one chunk."""
    result = split_into_best_sentences(tokenizer, "Hello world.", 50)
    assert len(result) == 1
18+
19+
20+
def test_multiple_sentences_split(tokenizer):
    """Several sentences that together exceed max_tokens get broken apart."""
    text = "First sentence here. Second sentence here. Third sentence here. Fourth sentence here."
    result = split_into_best_sentences(tokenizer, text, 10)
    assert len(result) > 1
25+
26+
27+
def test_long_sentence_with_commas_is_split(tokenizer):
    """A long period-free sentence must fall back to splitting on commas."""
    # Core regression case from issue #38: the Tale of Two Cities opening.
    text = (
        "It was the best of times, it was the worst of times, "
        "it was the age of wisdom, it was the age of foolishness, "
        "it was the epoch of belief, it was the epoch of incredulity, "
        "it was the season of Light, it was the season of Darkness, "
        "it was the spring of hope, it was the winter of despair"
    )
    chunks = split_into_best_sentences(tokenizer, text, 50)
    assert len(chunks) > 1, "Long comma-separated text should be split into multiple chunks"

    # The fallback split must not drop any words.
    rejoined = " ".join(chunks).lower()
    expected_phrases = [
        "best of times",
        "worst of times",
        "age of foolishness",
        "winter of despair",
    ]
    for phrase in expected_phrases:
        assert phrase in rejoined, f"'{phrase}' should be preserved after splitting"
44+
45+
46+
def test_long_sentence_with_commas_respects_max_tokens(tokenizer):
    """Chunks produced by comma splitting stay close to the token budget."""
    text = (
        "It was the best of times, it was the worst of times, "
        "it was the age of wisdom, it was the age of foolishness, "
        "it was the epoch of belief, it was the epoch of incredulity"
    )
    max_tokens = 20
    for chunk in split_into_best_sentences(tokenizer, text, max_tokens):
        token_count = len(tokenizer(chunk.strip()).tokens[0].tolist())
        # Comma clauses vary in size, so allow a 2x tolerance over the budget.
        assert token_count <= max_tokens * 2, (
            f"Chunk '{chunk[:50]}...' has {token_count} tokens, expected ~{max_tokens}"
        )
61+
62+
63+
def test_mixed_sentences_and_commas(tokenizer):
    """Inputs mixing normal sentences with one long comma-heavy sentence still split."""
    text = (
        "Short sentence. "
        "This is a very long sentence with many clauses, separated by commas, "
        "that goes on and on, and on some more, without any periods at all, "
        "until it finally reaches a period. "
        "Another short one."
    )
    result = split_into_best_sentences(tokenizer, text, 20)
    assert len(result) >= 3
74+
75+
76+
def test_no_commas_no_periods_stays_single_chunk(tokenizer):
    """With no punctuation to split on, the whole text stays as one chunk."""
    text = "one two three four five six seven eight nine ten eleven twelve"
    result = split_into_best_sentences(tokenizer, text, 5)
    # No split points exist, so a single oversized chunk is expected.
    assert len(result) == 1
82+
83+
84+
def test_semicolons_and_colons_also_split(tokenizer):
    """Semicolons (and colons) count as fallback split points, like commas."""
    text = (
        "First clause here; second clause here; third clause here; "
        "fourth clause here; fifth clause here; sixth clause here"
    )
    result = split_into_best_sentences(tokenizer, text, 15)
    assert len(result) > 1
92+
93+
94+
def test_short_sentence_not_affected_by_comma_splitting(tokenizer):
    """A sentence under max_tokens is left whole even though it contains a comma."""
    chunks = split_into_best_sentences(tokenizer, "Hello, world.", 50)
    assert len(chunks) == 1
    lowered = chunks[0].lower()
    assert "hello" in lowered
    assert "world" in lowered
101+
102+
103+
def test_empty_string_raises(tokenizer):
    """An empty prompt is rejected with ValueError (raised by prepare_text_prompt)."""
    with pytest.raises(ValueError, match="empty"):
        split_into_best_sentences(tokenizer, "", 50)
107+
108+
109+
def test_oversized_clause_without_commas_still_returns(tokenizer):
    """An unsplittable oversized clause is returned whole rather than dropped."""
    # Twenty punctuation-free words: there is nothing to split on.
    words = [f"word{i}" for i in range(20)]
    text = " ".join(words)
    chunks = split_into_best_sentences(tokenizer, text, 5)
    assert len(chunks) == 1
    # prepare_text_prompt capitalizes the first char and adds a trailing period,
    # so normalize case and strip the trailing dot before comparing.
    assert chunks[0].lower().rstrip(".") == text.lower()

0 commit comments

Comments
 (0)