99import logging
1010from ast import literal_eval
1111from pathlib import Path
12- from typing import Dict , Optional
12+ from typing import Dict , Literal , Optional
1313
1414import msgpack
1515import numpy as np
@@ -211,7 +211,7 @@ def _extract_msgpack_data(data: bytes, **kwargs):
211211 return input_data , extensions
212212
213213
214- def _get_serialization_path (path : Path , format_type : str = "auto" ) -> Path :
214+ def _get_serialization_path (path : Path , format_type : Literal [ "json" , "msgpack" , "auto" ] = "auto" ) -> Path :
215215 """Get the correct path for serialization format.
216216
217217 Args:
@@ -221,19 +221,22 @@ def _get_serialization_path(path: Path, format_type: str = "auto") -> Path:
221221 Returns:
222222 Path: Path with correct extension
223223 """
224+ JSON_EXTENSIONS = [".json" ]
225+ MSGPACK_EXTENSIONS = [".msgpack" , ".mp" ]
226+
224227 if format_type == "auto" :
225- if path .suffix .lower () in [ ".json" ] :
228+ if path .suffix .lower () in JSON_EXTENSIONS :
226229 format_type = "json"
227- elif path .suffix .lower () in [ ".msgpack" , ".mp" ]:
230+ elif path .suffix .lower () in MSGPACK_EXTENSIONS
228231 format_type = "msgpack"
229232 else :
230233 # Default to JSON
231234 format_type = "json"
232235
233- if format_type == "json" and path .suffix .lower () != ".json" :
234- return path .with_suffix (".json" )
235- if format_type == "msgpack" and path .suffix .lower () not in [ ".msgpack" , ".mp" ] :
236- return path .with_suffix (".msgpack" )
236+ if format_type == "json" and path .suffix .lower () != JSON_EXTENSIONS [ 0 ] :
237+ return path .with_suffix (JSON_EXTENSIONS [ 0 ] )
238+ if format_type == "msgpack" and path .suffix .lower () not in MSGPACK_EXTENSIONS :
239+ return path .with_suffix (MSGPACK_EXTENSIONS [ 0 ] )
237240
238241 return path
239242
0 commit comments