@@ -73,6 +73,13 @@ def allowed_value(
73
73
return True , value
74
74
75
75
76
+ def string_to_bool (val ) -> bool :
77
+ try :
78
+ return {"true" : True , "1" : True , "0" : False , "false" : False }[val .lower ()]
79
+ except KeyError :
80
+ return ValueError
81
+
82
+
76
83
def search_metadata (
77
84
param_name : str , value : Union [str , bool , float , list [float ]], netuid : int , metadata
78
85
) -> tuple [bool , Optional [dict ]]:
@@ -91,12 +98,6 @@ def search_metadata(
91
98
92
99
"""
93
100
94
- def string_to_bool (val ) -> bool :
95
- try :
96
- return {"true" : True , "1" : True , "0" : False , "false" : False }[val .lower ()]
97
- except KeyError :
98
- return ValueError
99
-
100
101
def type_converter_with_retry (type_ , val , arg_name ):
101
102
try :
102
103
if val is None :
@@ -112,31 +113,49 @@ def type_converter_with_retry(type_, val, arg_name):
112
113
113
114
call_crafter = {"netuid" : netuid }
114
115
115
- for pallet in metadata .pallets :
116
- if pallet .name == "AdminUtils" :
117
- for call in pallet .calls :
118
- if call .name == param_name :
119
- if "netuid" not in [x .name for x in call .args ]:
120
- return False , None
121
- call_args = [
122
- arg for arg in call .args if arg .value ["name" ] != "netuid"
123
- ]
124
- if len (call_args ) == 1 :
125
- arg = call_args [0 ].value
126
- call_crafter [arg ["name" ]] = type_converter_with_retry (
127
- arg ["typeName" ], value , arg ["name" ]
128
- )
129
- else :
130
- for arg_ in call_args :
131
- arg = arg_ .value
132
- call_crafter [arg ["name" ]] = type_converter_with_retry (
133
- arg ["typeName" ], None , arg ["name" ]
134
- )
135
- return True , call_crafter
116
+ pallet = metadata .get_metadata_pallet ("AdminUtils" )
117
+ for call in pallet .calls :
118
+ if call .name == param_name :
119
+ if "netuid" not in [x .name for x in call .args ]:
120
+ return False , None
121
+ call_args = [arg for arg in call .args if arg .value ["name" ] != "netuid" ]
122
+ if len (call_args ) == 1 :
123
+ arg = call_args [0 ].value
124
+ call_crafter [arg ["name" ]] = type_converter_with_retry (
125
+ arg ["typeName" ], value , arg ["name" ]
126
+ )
127
+ else :
128
+ for arg_ in call_args :
129
+ arg = arg_ .value
130
+ call_crafter [arg ["name" ]] = type_converter_with_retry (
131
+ arg ["typeName" ], None , arg ["name" ]
132
+ )
133
+ return True , call_crafter
136
134
else :
137
135
return False , None
138
136
139
137
138
+ def requires_bool (metadata , param_name ) -> bool :
139
+ """
140
+ Determines whether a given hyperparam takes a single arg (besides netuid) that is of bool type.
141
+ """
142
+ pallet = metadata .get_metadata_pallet ("AdminUtils" )
143
+ for call in pallet .calls :
144
+ if call .name == param_name :
145
+ if "netuid" not in [x .name for x in call .args ]:
146
+ return False , None
147
+ call_args = [arg for arg in call .args if arg .value ["name" ] != "netuid" ]
148
+ if len (call_args ) != 1 :
149
+ return False
150
+ else :
151
+ arg = call_args [0 ].value
152
+ if arg ["typeName" ] == "bool" :
153
+ return True
154
+ else :
155
+ return False
156
+ raise ValueError (f"{ param_name } not found in pallet." )
157
+
158
+
140
159
async def set_hyperparameter_extrinsic (
141
160
subtensor : "SubtensorInterface" ,
142
161
wallet : "Wallet" ,
@@ -221,15 +240,20 @@ async def set_hyperparameter_extrinsic(
221
240
]
222
241
223
242
if len (value ) < len (non_netuid_fields ):
224
- raise ValueError (
243
+ err_console . print (
225
244
"Not enough values provided in the list for all parameters"
226
245
)
246
+ return False
227
247
228
248
call_params .update (
229
249
{str (name ): val for name , val in zip (non_netuid_fields , value )}
230
250
)
231
251
232
252
else :
253
+ if requires_bool (
254
+ substrate .metadata , param_name = extrinsic
255
+ ) and isinstance (value , str ):
256
+ value = string_to_bool (value )
233
257
value_argument = extrinsic_params ["fields" ][
234
258
len (extrinsic_params ["fields" ]) - 1
235
259
]
@@ -252,12 +276,13 @@ async def set_hyperparameter_extrinsic(
252
276
)
253
277
if not success :
254
278
err_console .print (f":cross_mark: [red]Failed[/red]: { err_msg } " )
255
- await asyncio . sleep ( 0.5 )
279
+ return False
256
280
elif arbitrary_extrinsic :
257
281
console .print (
258
282
f":white_heavy_check_mark: "
259
283
f"[dark_sea_green3]Hyperparameter { parameter } values changed to { call_params } [/dark_sea_green3]"
260
284
)
285
+ return True
261
286
# Successful registration, final check for membership
262
287
else :
263
288
console .print (
@@ -649,17 +674,18 @@ async def get_hyperparameters(
649
674
)
650
675
dict_out = []
651
676
652
- normalized_values = normalize_hyperparameters (subnet )
653
-
677
+ normalized_values = normalize_hyperparameters (subnet , json_output = json_output )
654
678
for param , value , norm_value in normalized_values :
655
- table .add_row (" " + param , value , norm_value )
656
- dict_out .append (
657
- {
658
- "hyperparameter" : param ,
659
- "value" : value ,
660
- "normalized_value" : norm_value ,
661
- }
662
- )
679
+ if not json_output :
680
+ table .add_row (" " + param , value , norm_value )
681
+ else :
682
+ dict_out .append (
683
+ {
684
+ "hyperparameter" : param ,
685
+ "value" : value ,
686
+ "normalized_value" : norm_value ,
687
+ }
688
+ )
663
689
if json_output :
664
690
json_console .print (json .dumps (dict_out ))
665
691
else :
0 commit comments