@@ -73,6 +73,13 @@ def allowed_value(
73
73
return True , value
74
74
75
75
76
+ def string_to_bool (val ) -> bool :
77
+ try :
78
+ return {"true" : True , "1" : True , "0" : False , "false" : False }[val .lower ()]
79
+ except KeyError :
80
+ return ValueError
81
+
82
+
76
83
def search_metadata (
77
84
param_name : str , value : Union [str , bool , float , list [float ]], netuid : int , metadata
78
85
) -> tuple [bool , Optional [dict ]]:
@@ -91,12 +98,6 @@ def search_metadata(
91
98
92
99
"""
93
100
94
- def string_to_bool (val ) -> bool :
95
- try :
96
- return {"true" : True , "1" : True , "0" : False , "false" : False }[val .lower ()]
97
- except KeyError :
98
- return ValueError
99
-
100
101
def type_converter_with_retry (type_ , val , arg_name ):
101
102
try :
102
103
if val is None :
@@ -112,37 +113,55 @@ def type_converter_with_retry(type_, val, arg_name):
112
113
113
114
call_crafter = {"netuid" : netuid }
114
115
115
- for pallet in metadata .pallets :
116
- if pallet .name == "AdminUtils" :
117
- for call in pallet .calls :
118
- if call .name == param_name :
119
- if "netuid" not in [x .name for x in call .args ]:
120
- return False , None
121
- call_args = [
122
- arg for arg in call .args if arg .value ["name" ] != "netuid"
123
- ]
124
- if len (call_args ) == 1 :
125
- arg = call_args [0 ].value
126
- call_crafter [arg ["name" ]] = type_converter_with_retry (
127
- arg ["typeName" ], value , arg ["name" ]
128
- )
129
- else :
130
- for arg_ in call_args :
131
- arg = arg_ .value
132
- call_crafter [arg ["name" ]] = type_converter_with_retry (
133
- arg ["typeName" ], None , arg ["name" ]
134
- )
135
- return True , call_crafter
116
+ pallet = metadata .get_metadata_pallet ("AdminUtils" )
117
+ for call in pallet .calls :
118
+ if call .name == param_name :
119
+ if "netuid" not in [x .name for x in call .args ]:
120
+ return False , None
121
+ call_args = [arg for arg in call .args if arg .value ["name" ] != "netuid" ]
122
+ if len (call_args ) == 1 :
123
+ arg = call_args [0 ].value
124
+ call_crafter [arg ["name" ]] = type_converter_with_retry (
125
+ arg ["typeName" ], value , arg ["name" ]
126
+ )
127
+ else :
128
+ for arg_ in call_args :
129
+ arg = arg_ .value
130
+ call_crafter [arg ["name" ]] = type_converter_with_retry (
131
+ arg ["typeName" ], None , arg ["name" ]
132
+ )
133
+ return True , call_crafter
136
134
else :
137
135
return False , None
138
136
139
137
138
+ def requires_bool (metadata , param_name ) -> bool :
139
+ """
140
+ Determines whether a given hyperparam takes a single arg (besides netuid) that is of bool type.
141
+ """
142
+ pallet = metadata .get_metadata_pallet ("AdminUtils" )
143
+ for call in pallet .calls :
144
+ if call .name == param_name :
145
+ if "netuid" not in [x .name for x in call .args ]:
146
+ return False , None
147
+ call_args = [arg for arg in call .args if arg .value ["name" ] != "netuid" ]
148
+ if len (call_args ) != 1 :
149
+ return False
150
+ else :
151
+ arg = call_args [0 ].value
152
+ if arg ["typeName" ] == "bool" :
153
+ return True
154
+ else :
155
+ return False
156
+ raise ValueError (f"{ param_name } not found in pallet." )
157
+
158
+
140
159
async def set_hyperparameter_extrinsic (
141
160
subtensor : "SubtensorInterface" ,
142
161
wallet : "Wallet" ,
143
162
netuid : int ,
144
163
parameter : str ,
145
- value : Optional [Union [str , bool , float , list [float ]]],
164
+ value : Optional [Union [str , float , list [float ]]],
146
165
wait_for_inclusion : bool = False ,
147
166
wait_for_finalization : bool = True ,
148
167
prompt : bool = True ,
@@ -221,15 +240,20 @@ async def set_hyperparameter_extrinsic(
221
240
]
222
241
223
242
if len (value ) < len (non_netuid_fields ):
224
- raise ValueError (
243
+ err_console . print (
225
244
"Not enough values provided in the list for all parameters"
226
245
)
246
+ return False
227
247
228
248
call_params .update (
229
249
{str (name ): val for name , val in zip (non_netuid_fields , value )}
230
250
)
231
251
232
252
else :
253
+ if requires_bool (
254
+ substrate .metadata , param_name = extrinsic
255
+ ) and isinstance (value , str ):
256
+ value = string_to_bool (value )
233
257
value_argument = extrinsic_params ["fields" ][
234
258
len (extrinsic_params ["fields" ]) - 1
235
259
]
@@ -252,12 +276,13 @@ async def set_hyperparameter_extrinsic(
252
276
)
253
277
if not success :
254
278
err_console .print (f":cross_mark: [red]Failed[/red]: { err_msg } " )
255
- await asyncio . sleep ( 0.5 )
279
+ return False
256
280
elif arbitrary_extrinsic :
257
281
console .print (
258
282
f":white_heavy_check_mark: "
259
283
f"[dark_sea_green3]Hyperparameter { parameter } values changed to { call_params } [/dark_sea_green3]"
260
284
)
285
+ return True
261
286
# Successful registration, final check for membership
262
287
else :
263
288
console .print (
@@ -581,28 +606,11 @@ async def sudo_set_hyperparameter(
581
606
json_output : bool ,
582
607
):
583
608
"""Set subnet hyperparameters."""
584
-
585
- normalized_value : Union [str , bool ]
586
- if param_name in [
587
- "registration_allowed" ,
588
- "network_pow_registration_allowed" ,
589
- "commit_reveal_weights_enabled" ,
590
- "liquid_alpha_enabled" ,
591
- ]:
592
- normalized_value = param_value .lower () in ["true" , "1" ]
593
- elif param_value in ("True" , "False" ):
594
- normalized_value = {
595
- "True" : True ,
596
- "False" : False ,
597
- }[param_value ]
598
- else :
599
- normalized_value = param_value
600
-
601
- is_allowed_value , value = allowed_value (param_name , normalized_value )
609
+ is_allowed_value , value = allowed_value (param_name , param_value )
602
610
if not is_allowed_value :
603
611
err_console .print (
604
612
f"Hyperparameter [dark_orange]{ param_name } [/dark_orange] value is not within bounds. "
605
- f"Value is { normalized_value } but must be { value } "
613
+ f"Value is { param_value } but must be { value } "
606
614
)
607
615
return False
608
616
success = await set_hyperparameter_extrinsic (
@@ -625,8 +633,9 @@ async def get_hyperparameters(
625
633
if not await subtensor .subnet_exists (netuid ):
626
634
print_error (f"Subnet with netuid { netuid } does not exist." )
627
635
return False
628
- subnet = await subtensor .get_subnet_hyperparameters (netuid )
629
- subnet_info = await subtensor .subnet (netuid )
636
+ subnet , subnet_info = await asyncio .gather (
637
+ subtensor .get_subnet_hyperparameters (netuid ), subtensor .subnet (netuid )
638
+ )
630
639
if subnet_info is None :
631
640
print_error (f"Subnet with netuid { netuid } does not exist." )
632
641
return False
@@ -648,17 +657,18 @@ async def get_hyperparameters(
648
657
)
649
658
dict_out = []
650
659
651
- normalized_values = normalize_hyperparameters (subnet )
652
-
660
+ normalized_values = normalize_hyperparameters (subnet , json_output = json_output )
653
661
for param , value , norm_value in normalized_values :
654
- table .add_row (" " + param , value , norm_value )
655
- dict_out .append (
656
- {
657
- "hyperparameter" : param ,
658
- "value" : value ,
659
- "normalized_value" : norm_value ,
660
- }
661
- )
662
+ if not json_output :
663
+ table .add_row (" " + param , value , norm_value )
664
+ else :
665
+ dict_out .append (
666
+ {
667
+ "hyperparameter" : param ,
668
+ "value" : value ,
669
+ "normalized_value" : norm_value ,
670
+ }
671
+ )
662
672
if json_output :
663
673
json_console .print (json .dumps (dict_out ))
664
674
else :
0 commit comments