@@ -5,6 +5,7 @@ import json
55from libcpp cimport bool
66from libcpp.string cimport string
77from libcpp.vector cimport vector
8+ from libcpp.optional cimport optional
89
910cdef extern from " common/params.h" :
1011 cpdef enum ParamKeyFlag:
@@ -36,7 +37,7 @@ cdef extern from "common/params.h":
3637 int putBool(string, bool ) nogil
3738 bool checkKey(string) nogil
3839 ParamKeyType getKeyType(string) nogil
39- string getKeyDefaultValue(string) nogil
40+ optional[ string] getKeyDefaultValue(string) nogil
4041 string getParamPath(string) nogil
4142 void clearAll(ParamKeyFlag)
4243 vector[string] allKeys()
@@ -73,40 +74,45 @@ cdef class Params:
7374 raise UnknownKeyName(key)
7475 return key
7576
76- def get (self , key , bool block = False , default = None ):
77+ def cast (self , t , value , default ):
78+ if value is None :
79+ return None
80+ try :
81+ if t == STRING:
82+ return value.decode(" utf-8" )
83+ elif t == BOOL:
84+ return value == b" 1"
85+ elif t == INT:
86+ return int (value)
87+ elif t == FLOAT:
88+ return float (value)
89+ elif t == TIME:
90+ return datetime.datetime.fromisoformat(value.decode(" utf-8" ))
91+ elif t == JSON:
92+ return json.loads(value)
93+ elif t == BYTES:
94+ return value
95+ else :
96+ raise TypeError ()
97+ except (TypeError , ValueError ):
98+ return self .cast(t, default, None )
99+
100+ def get (self , key , bool block = False , bool return_default = False ):
77101 cdef string k = self .check_key(key)
78- cdef ParamKeyType t = self .p.getKeyType(ensure_bytes(key) )
102+ cdef ParamKeyType t = self .p.getKeyType(k )
79103 cdef string val
80104 with nogil:
81105 val = self .p.get(k, block)
82106
107+ default_val = self .get_default_value(k) if return_default else None
83108 if val == b" " :
84109 if block:
85110 # If we got no value while running in blocked mode
86111 # it means we got an interrupt while waiting
87112 raise KeyboardInterrupt
88113 else :
89- return default
90-
91- try :
92- if t == STRING:
93- return val.decode(" utf-8" )
94- elif t == BOOL:
95- return val == b" 1"
96- elif t == INT:
97- return int (val)
98- elif t == FLOAT:
99- return float (val)
100- elif t == TIME:
101- return datetime.datetime.fromisoformat(val.decode(" utf-8" ))
102- elif t == JSON:
103- return json.loads(val)
104- elif t == BYTES:
105- return val
106- else :
107- return default
108- except (TypeError , ValueError ):
109- return default
114+ return self .cast(t, default_val, None )
115+ return self .cast(t, val, default_val)
110116
111117 def get_bool (self , key , bool block = False ):
112118 cdef string k = self .check_key(key)
@@ -152,8 +158,12 @@ cdef class Params:
152158 cdef string key_bytes = ensure_bytes(key)
153159 return self .p.getParamPath(key_bytes).decode(" utf-8" )
154160
161+ def get_type (self , key ):
162+ return self .p.getKeyType(self .check_key(key))
163+
155164 def all_keys (self ):
156165 return self .p.allKeys()
157166
158167 def get_default_value (self , key ):
159- return self .p.getKeyDefaultValue(self .check_key(key))
168+ cdef optional[string] default = self .p.getKeyDefaultValue(self .check_key(key))
169+ return default.value() if default.has_value() else None
0 commit comments