 # limitations under the License.

 import pathlib
+import re
+import unicodedata
 from collections import defaultdict
 from typing import Dict, List, Optional, Union

+# Guard the optional dependency so the module still imports when pyopenjtalk
+# is missing; __init__ raises a descriptive ImportError in that case.
+try:
+    import pyopenjtalk
+except ImportError:
+    pyopenjtalk = None
+
 from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import (
+    GRAPHEME_CHARACTER_SETS,
     get_grapheme_character_set,
     get_ipa_punctuation_list,
 )
@@ -151,3 +156,193 @@ def __call__(self, text: str) -> List[str]:
             logging.warning(f"{word} not found in the pronunciation dictionary. Returning graphemes instead.")
             phoneme_seq += [c for c in word]
         return phoneme_seq
+
+
+class JapaneseKatakanaAccentG2p(BaseG2p):
| 162 | + """Japanese G2P module that converts text to Kana with pitch accent markers. |
| 163 | +
|
| 164 | + Converts Japanese text to katakana with pitch accent (0=low, 1=high) before each mora. |
| 165 | + Implements Japanese pitch accent rules for entire word chains. |
| 166 | +
|
| 167 | + Japanese pitch accent rules: |
| 168 | + - acc=0 (Heiban 平板): L-H-H-H... (first mora low, rest high) |
| 169 | + - acc=1 (Atamadaka 頭高): H-L-L-L... (first mora high, rest low) |
| 170 | + - acc=N (2 ≤ N < mora count, Nakadaka 中高): L-H-H...-H-L... (low, then high, drops after Nth mora) |
| 171 | + - acc=N (N >= mora count, Odaka 尾高): L-H-H-H (drop at end or after) |
| 172 | +
|
| 173 | + chain_flag handling: |
| 174 | + - chain_flag=0 or -1: Start of new word chain (use this word's acc) |
| 175 | + - chain_flag=1: Continuation of chain (ignore this word's acc) |
| 176 | + - Entire chain treated as single word with first word's acc and total mora count |
| 177 | +
|
| 178 | + Output format: [pitch, kana_char(s), pitch, kana_char(s), ...] where pitch (0/1) precedes each mora |
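+
+    Usage sketch (illustrative; requires pyopenjtalk, and the exact output depends
+    on the installed dictionary):
+
+        g2p = JapaneseKatakanaAccentG2p()
+        g2p("こんにちは")  # -> ['0', 'コ', '1', 'ン', '1', 'ニ', '1', 'チ', '1', 'ワ']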
| 179 | + """ |
| 180 | + |
+    def __init__(
+        self,
+        ascii_letter_prefix: str = "",
+        ascii_letter_case: str = "lower",
+        word_tokenize_func=None,
+        apply_to_oov_word=None,
+        mapping_file: Optional[str] = None,
+    ):
+        if pyopenjtalk is None:
+            raise ImportError("pyopenjtalk is required. Install with: pip install pyopenjtalk")
+        if ascii_letter_prefix is None:
+            ascii_letter_prefix = ""
+
+        self.ascii_letter_case = ascii_letter_case
+        # Load the Japanese katakana grapheme set
+        ja_graphemes = GRAPHEME_CHARACTER_SETS.get("ja-JP", [])
+
+        pitch_markers = ['0', '1']
+        self.phoneme_list = sorted(list(ja_graphemes) + pitch_markers)
+
+        # ASCII letters handling
+        self.ascii_letter_dict = {
+            x: ascii_letter_prefix + x for x in get_grapheme_character_set(locale="en-US", case=ascii_letter_case)
+        }
+        self.ascii_letter_list = sorted(self.ascii_letter_dict)
+
+        self.punctuation = get_ipa_punctuation_list('ja-JP')
+
+        super().__init__(
+            word_tokenize_func=word_tokenize_func,
+            apply_to_oov_word=apply_to_oov_word,
+            mapping_file=mapping_file,
+        )
+
+    @staticmethod
+    def _split_katakana_to_moras(katakana: str) -> List[str]:
+        """Split a katakana string into moras.
+
+        Mora pattern: [main_katakana][small_katakana]? | [standalone_small] | [choonpu]
+        """
+        mora_pattern = r'[ア-ンヴ][ャュョァィゥェォヮ]?|[ァィゥェォヵヶッャュョヮ]|ー'
+        return re.findall(mora_pattern, katakana)
+
+    def _get_pitch_pattern(self, acc: int, total_mora: int) -> List[int]:
+        """Calculate the pitch pattern for an entire word chain.
+
+        Args:
+            acc: Accent nucleus position from the first word in the chain
+            total_mora: Total mora count of the entire chain
+
+        Returns:
+            List of pitch values (0=low, 1=high) for each mora
+        """
+        if total_mora == 0:
+            return []
+
+        if acc == 0:  # Heiban: L-H-H-H...
+            return [0] + [1] * (total_mora - 1)
+
+        if acc == 1:  # Atamadaka: H-L-L-L...
+            return [1] + [0] * (total_mora - 1)
+
+        if acc >= total_mora:  # Odaka: L-H-H-H
+            return [0] + [1] * (total_mora - 1)
+
+        # Nakadaka: L-H...H-L...L (drop after the acc-th mora)
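+        # For example (illustrative): acc=3, total_mora=5 -> [0, 1, 1, 0, 0]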
+        return [0] + [1] * (acc - 1) + [0] * (total_mora - acc)
+
+    def _process_chain(self, chain: List[Dict], result: List[str]) -> None:
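+        """Append pitch-marked moras for one word chain to ``result`` (modified in place)."""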
+        if not chain:
+            return
+
+        # Find the chain starter (first word whose chain_flag is not 1)
+        chain_starter_idx = 0
+        for i, word in enumerate(chain):
+            if word['chain_flag'] != 1:
+                chain_starter_idx = i
+                break
+
+        chain_acc = chain[chain_starter_idx]['acc']
+
+        # Split all words into moras
+        all_moras = []
+        for word in chain:
+            moras = self._split_katakana_to_moras(word['pron'])
+            all_moras.extend(moras)
+
+        # Calculate the pitch pattern using the chain starter's accent
+        total_mora = len(all_moras)
+        pitch_pattern = self._get_pitch_pattern(chain_acc, total_mora)
+
+        # Build the output: one pitch marker, then the mora's kana character(s)
+        for mora, pitch in zip(all_moras, pitch_pattern):
+            result.append(str(pitch))
+            result.extend(mora)
+
+    def __call__(self, text: str) -> List[str]:
+        """Convert Japanese text to kana with pitch accent markers.
+
+        For example, the text "こんにちは" is converted to the list
+        `['0', 'コ', '1', 'ン', '1', 'ニ', '1', 'チ', '1', 'ワ']`.
+        """
+        text = set_grapheme_case(text, case=self.ascii_letter_case)
+
+        # njd: list of per-word feature dicts from pyopenjtalk's NJD (NAIST Japanese Dictionary) frontend
+        njd = pyopenjtalk.run_frontend(text)
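+        # Each entry looks roughly like this (illustrative values, not exact):
+        #   {'string': 'こんにちは', 'pron': 'コンニチワ', 'pos': '感動詞',
+        #    'acc': 0, 'mora_size': 5, 'chain_flag': -1, ...}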
+
+        result = []
+        current_chain = []
+        punctuation = self.punctuation
+
+        for idx, word in enumerate(njd):
+            if not isinstance(word, dict):
+                continue
+
+            pron = word.get('pron', '')
+            pos = word.get('pos', '')
+            string = word.get('string', '')
+            chain_flag = word.get('chain_flag', 0)
+            mora_size = word.get('mora_size', 0)
+            acc = word.get('acc', 0)
+
+            string = unicodedata.normalize('NFKC', string)
+
+            # Handle English letters
+            if string and all(c in self.ascii_letter_dict for c in string):
+                if current_chain:
+                    self._process_chain(current_chain, result)
+                    current_chain = []
+
+                result.extend(list(string))
+                continue
+
+            # Handle punctuation
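+            # ('記号' = symbol, '補助記号' = supplementary symbol; Japanese POS tags)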
+            if pos in ('記号', '補助記号'):
+                if current_chain:
+                    self._process_chain(current_chain, result)
+                    current_chain = []
+                if string.isspace():
+                    result.append(' ')
+                elif string in punctuation:
+                    result.append(string)
+                continue
+
+            if not pron or mora_size == 0:
+                continue
+
+            # Add word to the current chain
+            current_chain.append(
+                {
+                    'pron': pron,
+                    'acc': acc,
+                    'mora_size': mora_size,
+                    'chain_flag': chain_flag,
+                }
+            )
+
+            # Check if the next word continues the chain
+            next_has_chain = (
+                idx + 1 < len(njd) and isinstance(njd[idx + 1], dict) and njd[idx + 1].get('chain_flag', 0) == 1
+            )
+
+            # If the next word doesn't continue the chain, process the current chain
+            if not next_has_chain:
+                self._process_chain(current_chain, result)
+                current_chain = []
+
+        # Process any remaining chain
+        if current_chain:
+            self._process_chain(current_chain, result)
+        return result