Skip to content

Commit c64d854

Browse files
committed
feature(stable-ts): add word_tags for styling #154
1 parent 574e0d5 commit c64d854

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

src/subsai/models/stable_ts_model.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,20 @@ class StableTsModel(AbstractModel):
6363
'options': None,
6464
'default': None
6565
},
66+
'word_timestamps': {
67+
'type': bool,
68+
'description': 'Extract word-level timestamps using the cross-attention pattern'
69+
'and dynamic time warping, and include the timestamps for each word in each segment.',
70+
'options': None,
71+
'default': False
72+
},
73+
'word_tags': {
74+
'type': str,
75+
'description': 'When word_timestamps is True, you can use this option to wrap each word with specified tags for styling, e.g., <font color="#FFB600">,</font>. '
76+
'Separate open and close tags with a comma. Leave empty to display words on separate lines.',
77+
'options': None,
78+
'default': '<font color="#FFB600">,</font>'
79+
},
6680
'temperature': {
6781
'type': Tuple,
6882
'description': "Temperature for sampling. It can be a tuple of temperatures, which will be "
@@ -105,13 +119,6 @@ class StableTsModel(AbstractModel):
105119
'options': None,
106120
'default': None
107121
},
108-
'word_timestamps': {
109-
'type': bool,
110-
'description': 'Extract word-level timestamps using the cross-attention pattern'
111-
'and dynamic time warping, and include the timestamps for each word in each segment.',
112-
'options': None,
113-
'default': True
114-
},
115122
'regroup': {
116123
'type': bool,
117124
'description': "default True, meaning the default regroup algorithm"
@@ -410,6 +417,7 @@ def __init__(self, model_config):
410417
self._condition_on_previous_text = _load_config('condition_on_previous_text', model_config, self.config_schema)
411418
self._initial_prompt = _load_config('initial_prompt', model_config, self.config_schema)
412419
self._word_timestamps = _load_config('word_timestamps', model_config, self.config_schema)
420+
self._word_tags = _load_config('word_tags', model_config, self.config_schema)
413421
self._regroup = _load_config('regroup', model_config, self.config_schema)
414422
self._ts_num = _load_config('ts_num', model_config, self.config_schema)
415423
self._ts_noise = _load_config('ts_noise', model_config, self.config_schema)
@@ -474,7 +482,7 @@ def transcribe(self, media_file) -> SSAFile:
474482
k_size=self._k_size,
475483
time_scale=self._time_scale,
476484
demucs=self._demucs,
477-
demucs_output=self._demucs_output,
485+
# demucs_output=self._demucs_output,
478486
demucs_options=self._demucs_options,
479487
vad=self._vad,
480488
vad_threshold=self._vad_threshold,
@@ -500,7 +508,12 @@ def transcribe(self, media_file) -> SSAFile:
500508
for word in segment.words:
501509
try:
502510
event = SSAEvent(start=pysubs2.make_time(s=word.start), end=pysubs2.make_time(s=word.end))
503-
event.plaintext = word.word.strip()
511+
print(f"word tags: {self._word_tags}")
512+
if self._word_tags != '' and self._word_tags is not None:
513+
opening_tag, closing_tag = self._word_tags.split(',')
514+
event.plaintext = segment.text.replace(word.word, f'{opening_tag}{word.word}{closing_tag}')
515+
else:
516+
event.plaintext = word.word.strip()
504517
subs.append(event)
505518
except Exception as e:
506519
logging.warning(f"Something wrong with {word}")

0 commit comments

Comments
 (0)