diff --git a/paddlespeech/t2s/frontend/g2pw/onnx_api.py b/paddlespeech/t2s/frontend/g2pw/onnx_api.py index 3ce3d246dfc..110f8bfd8e3 100644 --- a/paddlespeech/t2s/frontend/g2pw/onnx_api.py +++ b/paddlespeech/t2s/frontend/g2pw/onnx_api.py @@ -41,6 +41,33 @@ model_version = '1.1' +def get_g2pw_model_path(model_dir: os.PathLike, model_version: str) -> str: + """Resolve the G2PW ONNX model directory path. + + Checks if the model file 'g2pW.onnx' exists in the expected location. + If not, downloads and decompresses the model archive + + Args: + model_dir (os.PathLike): Base directory to store models (e.g., ~/.paddlespeech). + model_version (str): Model version string (e.g., '1.1'). + + Returns: + str: Path to the model directory containing 'g2pW.onnx'. + """ + + archive_info = g2pw_onnx_models['G2PWModel'][model_version] + archive_fname = os.path.basename( + archive_info['url']) # e.g., "G2PWModel_1.1.zip" + expected_extract_name = os.path.splitext(archive_fname)[ + 0] # e.g., "G2PWModel_1.1" + expected_model_dir = os.path.join(model_dir, expected_extract_name) + uncompress_path = expected_model_dir + onnx_file_path = os.path.join(expected_model_dir, 'g2pW.onnx') + if not os.path.isfile(onnx_file_path): + uncompress_path = download_and_decompress(archive_info, model_dir) + return uncompress_path + + def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]: all_preds = [] @@ -70,8 +97,7 @@ def __init__(self, style: str='bopomofo', model_source: str=None, enable_non_tradional_chinese: bool=False): - uncompress_path = download_and_decompress( - g2pw_onnx_models['G2PWModel'][model_version], model_dir) + uncompress_path = get_g2pw_model_path(model_dir, model_version) sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL