Skip to content

Commit de7d98f

Browse files
authored
Merge pull request #2798 from opentensor/feat/thewhaleking/process_weights_for_netuid
process weights for netuid
2 parents 5d68312 + 7aec7ae commit de7d98f

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

bittensor/utils/weight_utils.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def process_weights_for_netuid(
245245
"""
246246

247247
logging.debug("process_weights_for_netuid()")
248-
logging.debug(f"weights: {weights}")
249248
logging.debug(f"netuid {netuid}")
250249
logging.debug(f"subtensor: {subtensor}")
251250
logging.debug(f"metagraph: {metagraph}")
@@ -254,6 +253,48 @@ def process_weights_for_netuid(
254253
if metagraph is None:
255254
metagraph = subtensor.metagraph(netuid)
256255

256+
return process_weights(
257+
uids=uids,
258+
weights=weights,
259+
num_neurons=metagraph.n,
260+
min_allowed_weights=subtensor.min_allowed_weights(netuid=netuid),
261+
max_weight_limit=subtensor.max_weight_limit(netuid=netuid),
262+
exclude_quantile=exclude_quantile,
263+
)
264+
265+
266+
def process_weights(
267+
uids: Union[NDArray[np.int64], "torch.Tensor"],
268+
weights: Union[NDArray[np.float32], "torch.Tensor"],
269+
num_neurons: int,
270+
min_allowed_weights: Optional[int],
271+
max_weight_limit: Optional[float],
272+
exclude_quantile: int = 0,
273+
) -> Union[
274+
tuple["torch.Tensor", "torch.FloatTensor"],
275+
tuple[NDArray[np.int64], NDArray[np.float32]],
276+
]:
277+
"""
278+
Processes weight tensors for a given weights and UID arrays and hyperparams, applying constraints
279+
and normalization based on the subtensor and metagraph data. This function can handle both NumPy arrays and PyTorch
280+
tensors.
281+
282+
Args:
283+
uids (Union[NDArray[np.int64], "torch.Tensor"]): Array of unique identifiers of the neurons.
284+
weights (Union[NDArray[np.float32], "torch.Tensor"]): Array of weights associated with the user IDs.
285+
num_neurons (int): The number of neurons in the network.
286+
min_allowed_weights (Optional[int]): Subnet hyperparam Minimum number of allowed weights.
287+
max_weight_limit (Optional[float]): Subnet hyperparam Maximum weight limit.
288+
exclude_quantile (int): Quantile threshold for excluding lower weights. Defaults to ``0``.
289+
290+
Returns:
291+
Union[tuple["torch.Tensor", "torch.FloatTensor"], tuple[NDArray[np.int64], NDArray[np.float32]]]: tuple
292+
containing the array of user IDs and the corresponding normalized weights. The data type of the return
293+
matches the type of the input weights (NumPy or PyTorch).
294+
"""
295+
logging.debug("process_weights()")
296+
logging.debug(f"weights: {weights}")
297+
257298
# Cast weights to floats.
258299
if use_torch():
259300
if not isinstance(weights, torch.FloatTensor):
@@ -265,8 +306,6 @@ def process_weights_for_netuid(
265306
# Network configuration parameters from an subtensor.
266307
# These parameters determine the range of acceptable weights for each neuron.
267308
quantile = exclude_quantile / U16_MAX
268-
min_allowed_weights = subtensor.min_allowed_weights(netuid=netuid)
269-
max_weight_limit = subtensor.max_weight_limit(netuid=netuid)
270309
logging.debug(f"quantile: {quantile}")
271310
logging.debug(f"min_allowed_weights: {min_allowed_weights}")
272311
logging.debug(f"max_weight_limit: {max_weight_limit}")
@@ -280,12 +319,12 @@ def process_weights_for_netuid(
280319
non_zero_weight_uids = uids[non_zero_weight_idx]
281320
non_zero_weights = weights[non_zero_weight_idx]
282321
nzw_size = non_zero_weights.numel() if use_torch() else non_zero_weights.size
283-
if nzw_size == 0 or metagraph.n < min_allowed_weights:
322+
if nzw_size == 0 or num_neurons < min_allowed_weights:
284323
logging.warning("No non-zero weights returning all ones.")
285324
final_weights = (
286-
torch.ones((metagraph.n)).to(metagraph.n) / metagraph.n
325+
torch.ones(num_neurons).to(num_neurons) / num_neurons
287326
if use_torch()
288-
else np.ones((metagraph.n), dtype=np.int64) / metagraph.n
327+
else np.ones(num_neurons, dtype=np.int64) / num_neurons
289328
)
290329
logging.debug(f"final_weights: {final_weights}")
291330
final_weights_count = (
@@ -303,11 +342,11 @@ def process_weights_for_netuid(
303342
logging.warning(
304343
"No non-zero weights less then min allowed weight, returning all ones."
305344
)
306-
# ( const ): Should this be np.zeros( ( metagraph.n ) ) to reset everyone to build up weight?
345+
# ( const ): Should this be np.zeros( ( num_neurons ) ) to reset everyone to build up weight?
307346
weights = (
308-
torch.ones((metagraph.n)).to(metagraph.n) * 1e-5
347+
torch.ones(num_neurons).to(num_neurons) * 1e-5
309348
if use_torch()
310-
else np.ones((metagraph.n), dtype=np.int64) * 1e-5
349+
else np.ones(num_neurons, dtype=np.int64) * 1e-5
311350
) # creating minimum even non-zero weights
312351
weights[non_zero_weight_idx] += non_zero_weights
313352
logging.debug(f"final_weights: {weights}")

0 commit comments

Comments
 (0)