Skip to content

Commit 40d81d1

Browse files
committed
Bug fixes + optmisation
1 parent b1f510d commit 40d81d1

File tree

2 files changed

+72
-41
lines changed

2 files changed

+72
-41
lines changed

bittensor_cli/src/bittensor/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,7 @@ def millify_tao(n: float, start_at: str = "K") -> str:
719719

720720
def normalize_hyperparameters(
721721
subnet: "SubnetHyperparameters",
722+
json_output: bool = False,
722723
) -> list[tuple[str, str, str]]:
723724
"""
724725
Normalizes the hyperparameters of a subnet.
@@ -750,13 +751,17 @@ def normalize_hyperparameters(
750751
norm_value = param_mappings[param](value)
751752
if isinstance(norm_value, float):
752753
norm_value = f"{norm_value:.{10}g}"
754+
if isinstance(norm_value, Balance) and json_output:
755+
norm_value = norm_value.to_dict()
753756
else:
754757
norm_value = value
755758
except Exception:
756759
# bittensor.logging.warning(f"Error normalizing parameter '{param}': {e}")
757760
norm_value = "-"
758-
759-
normalized_values.append((param, str(value), str(norm_value)))
761+
if not json_output:
762+
normalized_values.append((param, str(value), str(norm_value)))
763+
else:
764+
normalized_values.append((param, value, norm_value))
760765

761766
return normalized_values
762767

bittensor_cli/src/commands/sudo.py

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ def allowed_value(
7373
return True, value
7474

7575

76+
def string_to_bool(val) -> bool:
77+
try:
78+
return {"true": True, "1": True, "0": False, "false": False}[val.lower()]
79+
except KeyError:
80+
return ValueError
81+
82+
7683
def search_metadata(
7784
param_name: str, value: Union[str, bool, float, list[float]], netuid: int, metadata
7885
) -> tuple[bool, Optional[dict]]:
@@ -91,12 +98,6 @@ def search_metadata(
9198
9299
"""
93100

94-
def string_to_bool(val) -> bool:
95-
try:
96-
return {"true": True, "1": True, "0": False, "false": False}[val.lower()]
97-
except KeyError:
98-
return ValueError
99-
100101
def type_converter_with_retry(type_, val, arg_name):
101102
try:
102103
if val is None:
@@ -112,31 +113,49 @@ def type_converter_with_retry(type_, val, arg_name):
112113

113114
call_crafter = {"netuid": netuid}
114115

115-
for pallet in metadata.pallets:
116-
if pallet.name == "AdminUtils":
117-
for call in pallet.calls:
118-
if call.name == param_name:
119-
if "netuid" not in [x.name for x in call.args]:
120-
return False, None
121-
call_args = [
122-
arg for arg in call.args if arg.value["name"] != "netuid"
123-
]
124-
if len(call_args) == 1:
125-
arg = call_args[0].value
126-
call_crafter[arg["name"]] = type_converter_with_retry(
127-
arg["typeName"], value, arg["name"]
128-
)
129-
else:
130-
for arg_ in call_args:
131-
arg = arg_.value
132-
call_crafter[arg["name"]] = type_converter_with_retry(
133-
arg["typeName"], None, arg["name"]
134-
)
135-
return True, call_crafter
116+
pallet = metadata.get_metadata_pallet("AdminUtils")
117+
for call in pallet.calls:
118+
if call.name == param_name:
119+
if "netuid" not in [x.name for x in call.args]:
120+
return False, None
121+
call_args = [arg for arg in call.args if arg.value["name"] != "netuid"]
122+
if len(call_args) == 1:
123+
arg = call_args[0].value
124+
call_crafter[arg["name"]] = type_converter_with_retry(
125+
arg["typeName"], value, arg["name"]
126+
)
127+
else:
128+
for arg_ in call_args:
129+
arg = arg_.value
130+
call_crafter[arg["name"]] = type_converter_with_retry(
131+
arg["typeName"], None, arg["name"]
132+
)
133+
return True, call_crafter
136134
else:
137135
return False, None
138136

139137

138+
def requires_bool(metadata, param_name) -> bool:
139+
"""
140+
Determines whether a given hyperparam takes a single arg (besides netuid) that is of bool type.
141+
"""
142+
pallet = metadata.get_metadata_pallet("AdminUtils")
143+
for call in pallet.calls:
144+
if call.name == param_name:
145+
if "netuid" not in [x.name for x in call.args]:
146+
return False, None
147+
call_args = [arg for arg in call.args if arg.value["name"] != "netuid"]
148+
if len(call_args) != 1:
149+
return False
150+
else:
151+
arg = call_args[0].value
152+
if arg["typeName"] == "bool":
153+
return True
154+
else:
155+
return False
156+
raise ValueError(f"{param_name} not found in pallet.")
157+
158+
140159
async def set_hyperparameter_extrinsic(
141160
subtensor: "SubtensorInterface",
142161
wallet: "Wallet",
@@ -221,15 +240,20 @@ async def set_hyperparameter_extrinsic(
221240
]
222241

223242
if len(value) < len(non_netuid_fields):
224-
raise ValueError(
243+
err_console.print(
225244
"Not enough values provided in the list for all parameters"
226245
)
246+
return False
227247

228248
call_params.update(
229249
{str(name): val for name, val in zip(non_netuid_fields, value)}
230250
)
231251

232252
else:
253+
if requires_bool(
254+
substrate.metadata, param_name=extrinsic
255+
) and isinstance(value, str):
256+
value = string_to_bool(value)
233257
value_argument = extrinsic_params["fields"][
234258
len(extrinsic_params["fields"]) - 1
235259
]
@@ -252,12 +276,13 @@ async def set_hyperparameter_extrinsic(
252276
)
253277
if not success:
254278
err_console.print(f":cross_mark: [red]Failed[/red]: {err_msg}")
255-
await asyncio.sleep(0.5)
279+
return False
256280
elif arbitrary_extrinsic:
257281
console.print(
258282
f":white_heavy_check_mark: "
259283
f"[dark_sea_green3]Hyperparameter {parameter} values changed to {call_params}[/dark_sea_green3]"
260284
)
285+
return True
261286
# Successful registration, final check for membership
262287
else:
263288
console.print(
@@ -649,17 +674,18 @@ async def get_hyperparameters(
649674
)
650675
dict_out = []
651676

652-
normalized_values = normalize_hyperparameters(subnet)
653-
677+
normalized_values = normalize_hyperparameters(subnet, json_output=json_output)
654678
for param, value, norm_value in normalized_values:
655-
table.add_row(" " + param, value, norm_value)
656-
dict_out.append(
657-
{
658-
"hyperparameter": param,
659-
"value": value,
660-
"normalized_value": norm_value,
661-
}
662-
)
679+
if not json_output:
680+
table.add_row(" " + param, value, norm_value)
681+
else:
682+
dict_out.append(
683+
{
684+
"hyperparameter": param,
685+
"value": value,
686+
"normalized_value": norm_value,
687+
}
688+
)
663689
if json_output:
664690
json_console.print(json.dumps(dict_out))
665691
else:

0 commit comments

Comments
 (0)