88
99import inspect
1010import re
11- from typing import Any , Callable , Dict , List , Optional , Tuple , Type , TypeVar , Union
11+ from typing import Any , Callable , Optional , Tuple , TypeVar , Union
1212
13- import typing_inspect
1413
15-
16- def to_list (arg : str ) -> List [str ]:
14+ def to_list (arg : str ) -> list [str ]:
1715 conf = []
1816 if len (arg .strip ()) == 0 :
1917 return []
@@ -22,9 +20,9 @@ def to_list(arg: str) -> List[str]:
2220 return conf
2321
2422
25- def to_dict (arg : str ) -> Dict [str , str ]:
23+ def to_dict (arg : str ) -> dict [str , str ]:
2624 """
27- Parses the given ``arg`` string literal into a ``Dict [str, str]`` of
25+ Parses the given ``arg`` string literal into a ``dict [str, str]`` of
2826 key-value pairs delimited by ``"="`` (equals). The values may be a
2927 list literal where the list elements are delimited by ``","`` (comma)
3028 or ``";"`` (semi-colon). The same delimiters (``","`` and ``";"``) are used
@@ -85,14 +83,14 @@ def to_val(val: str) -> str:
8583 return val [1 :- 1 ]
8684 return val if val != '""' and val != "''" else ""
8785
88- arg_map : Dict [str , str ] = {}
86+ arg_map : dict [str , str ] = {}
8987
9088 if not arg :
9189 return arg_map
9290
9391 # find quoted values
9492 quoted_pattern = r'([\'"])((?:\\.|(?!\1).)*?)\1'
95- quoted_values : List [str ] = []
93+ quoted_values : list [str ] = []
9694
9795 def replace_quoted (match ):
9896 quoted_values .append (match .group (0 ))
@@ -133,19 +131,26 @@ def replace_quoted(match):
133131
134132# pyre-ignore-all-errors[3, 2]
135133def _decode_string_to_dict (
136- encoded_value : str , param_type : Type [Dict [Any , Any ]]
137- ) -> Dict [Any , Any ]:
138- key_type , value_type = typing_inspect .get_args (param_type )
134+ encoded_value : str , param_type : type [dict [Any , Any ]]
135+ ) -> dict [Any , Any ]:
136+ # pyre-ignore[16]
137+ if not hasattr (param_type , "__args__" ) or len (param_type .__args__ ) != 2 :
138+ raise ValueError (f"param_type must be a `dict` type, but was `{ param_type } `" )
139+
140+ key_type , value_type = param_type .__args__
139141 arg_values = {}
140142 for key , value in to_dict (encoded_value ).items ():
141143 arg_values [key_type (key )] = value_type (value )
142144 return arg_values
143145
144146
145147def _decode_string_to_list (
146- encoded_value : str , param_type : Type [List [Any ]]
147- ) -> List [Any ]:
148- value_type = typing_inspect .get_args (param_type )[0 ]
148+ encoded_value : str , param_type : type [list [Any ]]
149+ ) -> list [Any ]:
150+ # pyre-ignore[16]
151+ if not hasattr (param_type , "__args__" ) or len (param_type .__args__ ) != 1 :
152+ raise ValueError (f"param_type must be a `list` type, but was `{ param_type } `" )
153+ value_type = param_type .__args__ [0 ]
149154 if not is_primitive (value_type ):
150155 raise ValueError ("List types support only primitives: int, str, float" )
151156 arg_values = []
@@ -166,7 +171,7 @@ def decode(encoded_value: Any, annotation: Any):
166171
167172def decode_from_string (
168173 encoded_value : str , annotation : Any
169- ) -> Union [Dict [Any , Any ], List [Any ], None ]:
174+ ) -> Union [dict [Any , Any ], list [Any ], None ]:
170175 """Decodes string representation to the underlying type(Dict or List)
171176
172177 Given a string representation of the value, the method decodes it according
@@ -191,13 +196,13 @@ def decode_from_string(
191196 if not encoded_value :
192197 return None
193198 value_type = annotation
194- value_origin = typing_inspect . get_origin (value_type )
195- if value_origin is dict :
196- return _decode_string_to_dict ( encoded_value , value_type )
197- elif value_origin is list :
198- return _decode_string_to_list ( encoded_value , value_type )
199- else :
200- raise ValueError ("Unknown" )
199+ if hasattr (value_type , "__origin__" ):
200+ value_origin = value_type . __origin__
201+ if value_origin is dict :
202+ return _decode_string_to_dict ( encoded_value , value_type )
203+ elif value_origin is list :
204+ return _decode_string_to_list ( encoded_value , value_type )
205+ raise ValueError ("Unknown" )
201206
202207
203208def is_bool (param_type : Any ) -> bool :
@@ -229,12 +234,13 @@ def decode_optional(param_type: Any) -> Any:
229234 If ``param_type`` is type Optional[INNER_TYPE], method returns INNER_TYPE
230235 Otherwise returns ``param_type``
231236 """
232- param_origin = typing_inspect .get_origin (param_type )
233- if param_origin is not Union :
237+ if not hasattr (param_type , "__origin__" ):
238+ return param_type
239+ if param_type .__origin__ is not Union :
234240 return param_type
235- key_type , value_type = typing_inspect . get_args ( param_type )
236- if value_type is type (None ):
237- return key_type
241+ args = param_type . __args__
242+ if len ( args ) == 2 and args [ 1 ] is type (None ):
243+ return args [ 0 ]
238244 else :
239245 return param_type
240246
0 commit comments