@@ -245,7 +245,6 @@ def process_weights_for_netuid(
245
245
"""
246
246
247
247
logging .debug ("process_weights_for_netuid()" )
248
- logging .debug (f"weights: { weights } " )
249
248
logging .debug (f"netuid { netuid } " )
250
249
logging .debug (f"subtensor: { subtensor } " )
251
250
logging .debug (f"metagraph: { metagraph } " )
@@ -254,6 +253,48 @@ def process_weights_for_netuid(
254
253
if metagraph is None :
255
254
metagraph = subtensor .metagraph (netuid )
256
255
256
+ return process_weights (
257
+ uids = uids ,
258
+ weights = weights ,
259
+ num_neurons = metagraph .n ,
260
+ min_allowed_weights = subtensor .min_allowed_weights (netuid = netuid ),
261
+ max_weight_limit = subtensor .max_weight_limit (netuid = netuid ),
262
+ exclude_quantile = exclude_quantile ,
263
+ )
264
+
265
+
266
+ def process_weights (
267
+ uids : Union [NDArray [np .int64 ], "torch.Tensor" ],
268
+ weights : Union [NDArray [np .float32 ], "torch.Tensor" ],
269
+ num_neurons : int ,
270
+ min_allowed_weights : Optional [int ],
271
+ max_weight_limit : Optional [float ],
272
+ exclude_quantile : int = 0 ,
273
+ ) -> Union [
274
+ tuple ["torch.Tensor" , "torch.FloatTensor" ],
275
+ tuple [NDArray [np .int64 ], NDArray [np .float32 ]],
276
+ ]:
277
+ """
278
+ Processes weight tensors for a given weights and UID arrays and hyperparams, applying constraints
279
+ and normalization based on the subtensor and metagraph data. This function can handle both NumPy arrays and PyTorch
280
+ tensors.
281
+
282
+ Args:
283
+ uids (Union[NDArray[np.int64], "torch.Tensor"]): Array of unique identifiers of the neurons.
284
+ weights (Union[NDArray[np.float32], "torch.Tensor"]): Array of weights associated with the user IDs.
285
+ num_neurons (int): The number of neurons in the network.
286
+ min_allowed_weights (Optional[int]): Subnet hyperparam Minimum number of allowed weights.
287
+ max_weight_limit (Optional[float]): Subnet hyperparam Maximum weight limit.
288
+ exclude_quantile (int): Quantile threshold for excluding lower weights. Defaults to ``0``.
289
+
290
+ Returns:
291
+ Union[tuple["torch.Tensor", "torch.FloatTensor"], tuple[NDArray[np.int64], NDArray[np.float32]]]: tuple
292
+ containing the array of user IDs and the corresponding normalized weights. The data type of the return
293
+ matches the type of the input weights (NumPy or PyTorch).
294
+ """
295
+ logging .debug ("process_weights()" )
296
+ logging .debug (f"weights: { weights } " )
297
+
257
298
# Cast weights to floats.
258
299
if use_torch ():
259
300
if not isinstance (weights , torch .FloatTensor ):
@@ -265,8 +306,6 @@ def process_weights_for_netuid(
265
306
# Network configuration parameters from an subtensor.
266
307
# These parameters determine the range of acceptable weights for each neuron.
267
308
quantile = exclude_quantile / U16_MAX
268
- min_allowed_weights = subtensor .min_allowed_weights (netuid = netuid )
269
- max_weight_limit = subtensor .max_weight_limit (netuid = netuid )
270
309
logging .debug (f"quantile: { quantile } " )
271
310
logging .debug (f"min_allowed_weights: { min_allowed_weights } " )
272
311
logging .debug (f"max_weight_limit: { max_weight_limit } " )
@@ -280,12 +319,12 @@ def process_weights_for_netuid(
280
319
non_zero_weight_uids = uids [non_zero_weight_idx ]
281
320
non_zero_weights = weights [non_zero_weight_idx ]
282
321
nzw_size = non_zero_weights .numel () if use_torch () else non_zero_weights .size
283
- if nzw_size == 0 or metagraph . n < min_allowed_weights :
322
+ if nzw_size == 0 or num_neurons < min_allowed_weights :
284
323
logging .warning ("No non-zero weights returning all ones." )
285
324
final_weights = (
286
- torch .ones (( metagraph . n )) .to (metagraph . n ) / metagraph . n
325
+ torch .ones (num_neurons ) .to (num_neurons ) / num_neurons
287
326
if use_torch ()
288
- else np .ones (( metagraph . n ) , dtype = np .int64 ) / metagraph . n
327
+ else np .ones (num_neurons , dtype = np .int64 ) / num_neurons
289
328
)
290
329
logging .debug (f"final_weights: { final_weights } " )
291
330
final_weights_count = (
@@ -303,11 +342,11 @@ def process_weights_for_netuid(
303
342
logging .warning (
304
343
"No non-zero weights less then min allowed weight, returning all ones."
305
344
)
306
- # ( const ): Should this be np.zeros( ( metagraph.n ) ) to reset everyone to build up weight?
345
+ # ( const ): Should this be np.zeros( ( num_neurons ) ) to reset everyone to build up weight?
307
346
weights = (
308
- torch .ones (( metagraph . n )) .to (metagraph . n ) * 1e-5
347
+ torch .ones (num_neurons ) .to (num_neurons ) * 1e-5
309
348
if use_torch ()
310
- else np .ones (( metagraph . n ) , dtype = np .int64 ) * 1e-5
349
+ else np .ones (num_neurons , dtype = np .int64 ) * 1e-5
311
350
) # creating minimum even non-zero weights
312
351
weights [non_zero_weight_idx ] += non_zero_weights
313
352
logging .debug (f"final_weights: { weights } " )
0 commit comments