Skip to content

Commit b0a6a78

Browse files
committed
Ban saturating_div
1 parent 2ba69b1 commit b0a6a78

File tree

15 files changed

+114
-71
lines changed

15 files changed

+114
-71
lines changed

pallets/collective/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,8 @@ pub mod pallet {
549549
);
550550

551551
let threshold = T::GetVotingMembers::get_count()
552-
.saturating_div(2)
552+
.checked_div(2)
553+
.unwrap_or(0)
553554
.saturating_add(1);
554555

555556
let members = Self::members();

pallets/subtensor/src/coinbase/block_emission.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::*;
2+
use crate::epoch::math::*;
23
use frame_support::traits::Get;
34
use substrate_fixed::{transcendental::log2, types::I96F32};
45

@@ -139,7 +140,7 @@ impl<T: Config> Pallet<T> {
139140
for _ in 0..floored_residual_int {
140141
multiplier = multiplier.saturating_mul(I96F32::from_num(2.0));
141142
}
142-
let block_emission_percentage: I96F32 = I96F32::from_num(1.0).saturating_div(multiplier);
143+
let block_emission_percentage: I96F32 = I96F32::from_num(1.0).safe_div(multiplier);
143144
// Calculate the actual emission based on the emission rate
144145
let block_emission: I96F32 = block_emission_percentage
145146
.saturating_mul(I96F32::from_num(DefaultBlockEmission::<T>::get()));

pallets/subtensor/src/coinbase/block_step.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::*;
2+
use crate::epoch::math::*;
23
use frame_support::storage::IterableStorageMap;
34
use substrate_fixed::types::{I110F18, I96F32};
45

@@ -202,11 +203,11 @@ impl<T: Config + pallet_drand::Config> Pallet<T> {
202203
.saturating_mul(I110F18::from_num(
203204
registrations_this_interval.saturating_add(target_registrations_per_interval),
204205
))
205-
.saturating_div(I110F18::from_num(
206+
.safe_div(I110F18::from_num(
206207
target_registrations_per_interval.saturating_add(target_registrations_per_interval),
207208
));
208209
let alpha: I110F18 = I110F18::from_num(Self::get_adjustment_alpha(netuid))
209-
.saturating_div(I110F18::from_num(u64::MAX));
210+
.safe_div(I110F18::from_num(u64::MAX));
210211
let next_value: I110F18 = alpha
211212
.saturating_mul(I110F18::from_num(current_difficulty))
212213
.saturating_add(
@@ -236,11 +237,11 @@ impl<T: Config + pallet_drand::Config> Pallet<T> {
236237
.saturating_mul(I110F18::from_num(
237238
registrations_this_interval.saturating_add(target_registrations_per_interval),
238239
))
239-
.saturating_div(I110F18::from_num(
240+
.safe_div(I110F18::from_num(
240241
target_registrations_per_interval.saturating_add(target_registrations_per_interval),
241242
));
242243
let alpha: I110F18 = I110F18::from_num(Self::get_adjustment_alpha(netuid))
243-
.saturating_div(I110F18::from_num(u64::MAX));
244+
.safe_div(I110F18::from_num(u64::MAX));
244245
let next_value: I110F18 = alpha
245246
.saturating_mul(I110F18::from_num(current_burn))
246247
.saturating_add(

pallets/subtensor/src/coinbase/root.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// DEALINGS IN THE SOFTWARE.
1717

1818
use super::*;
19+
use crate::epoch::math::*;
1920
use frame_support::dispatch::Pays;
2021
use frame_support::storage::IterableStorageDoubleMap;
2122
use frame_support::weights::Weight;
@@ -624,7 +625,7 @@ impl<T: Config> Pallet<T> {
624625

625626
let mut lock_cost = last_lock.saturating_mul(mult).saturating_sub(
626627
last_lock
627-
.saturating_div(lock_reduction_interval)
628+
.safe_div(lock_reduction_interval)
628629
.saturating_mul(current_block.saturating_sub(last_lock_block)),
629630
);
630631

pallets/subtensor/src/coinbase/run_coinbase.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::*;
2+
use crate::epoch::math::*;
23
use alloc::collections::BTreeMap;
34
use substrate_fixed::types::I96F32;
45
use tle::stream_ciphers::AESGCMStreamCipherProvider;
@@ -557,7 +558,7 @@ impl<T: Config> Pallet<T> {
557558
for (proportion, _) in childkeys {
558559
remaining_proportion = remaining_proportion.saturating_sub(
559560
I96F32::from_num(proportion) // Normalize
560-
.saturating_div(I96F32::from_num(u64::MAX)),
561+
.safe_div(I96F32::from_num(u64::MAX)),
561562
);
562563
}
563564

@@ -608,7 +609,7 @@ impl<T: Config> Pallet<T> {
608609
let validating_emission: I96F32 = I96F32::from_num(dividends);
609610
let childkey_take_proportion: I96F32 =
610611
I96F32::from_num(Self::get_childkey_take(hotkey, netuid))
611-
.saturating_div(I96F32::from_num(u16::MAX));
612+
.safe_div(I96F32::from_num(u16::MAX));
612613
log::debug!(
613614
"Childkey take proportion: {:?} for hotkey {:?}",
614615
childkey_take_proportion,
@@ -656,7 +657,7 @@ impl<T: Config> Pallet<T> {
656657
for (proportion, parent) in Self::get_parents(hotkey, netuid) {
657658
// Convert the parent's stake proportion to a fractional value
658659
let parent_proportion: I96F32 =
659-
I96F32::from_num(proportion).saturating_div(I96F32::from_num(u64::MAX));
660+
I96F32::from_num(proportion).safe_div(I96F32::from_num(u64::MAX));
660661

661662
// Get the parent's root and subnet-specific (alpha) stakes
662663
let parent_root: I96F32 = I96F32::from_num(Self::get_stake_for_hotkey_on_subnet(

pallets/subtensor/src/epoch/math.rs

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use sp_std::cmp::Ordering;
88

99
use sp_std::vec;
1010
use substrate_fixed::transcendental::{exp, ln};
11-
use substrate_fixed::types::{I32F32, I64F64};
11+
use substrate_fixed::types::{I110F18, I32F32, I64F64, I96F32, U64F64};
1212

1313
// TODO: figure out what cfg gate this needs to not be a warning in rustc
1414
#[allow(unused)]
@@ -51,7 +51,7 @@ pub fn u16_to_fixed(x: u16) -> I32F32 {
5151

5252
#[allow(dead_code)]
5353
pub fn u16_proportion_to_fixed(x: u16) -> I32F32 {
54-
I32F32::from_num(x).saturating_div(I32F32::from_num(u16::MAX))
54+
I32F32::from_num(x).safe_div(I32F32::from_num(u16::MAX))
5555
}
5656

5757
#[allow(dead_code)]
@@ -107,7 +107,7 @@ pub fn vec_max_upscale_to_u16(vec: &[I32F32]) -> Vec<u16> {
107107
return vec
108108
.iter()
109109
.map(|e: &I32F32| {
110-
e.saturating_mul(u16_max.saturating_div(*val))
110+
e.saturating_mul(u16_max.safe_div(*val))
111111
.round()
112112
.to_num::<u16>()
113113
})
@@ -116,7 +116,7 @@ pub fn vec_max_upscale_to_u16(vec: &[I32F32]) -> Vec<u16> {
116116
vec.iter()
117117
.map(|e: &I32F32| {
118118
e.saturating_mul(u16_max)
119-
.saturating_div(*val)
119+
.safe_div(*val)
120120
.round()
121121
.to_num::<u16>()
122122
})
@@ -125,11 +125,7 @@ pub fn vec_max_upscale_to_u16(vec: &[I32F32]) -> Vec<u16> {
125125
None => {
126126
let sum: I32F32 = vec.iter().sum();
127127
vec.iter()
128-
.map(|e: &I32F32| {
129-
e.saturating_mul(u16_max)
130-
.saturating_div(sum)
131-
.to_num::<u16>()
132-
})
128+
.map(|e: &I32F32| e.saturating_mul(u16_max).safe_div(sum).to_num::<u16>())
133129
.collect()
134130
}
135131
}
@@ -145,8 +141,7 @@ pub fn vec_u16_max_upscale_to_u16(vec: &[u16]) -> Vec<u16> {
145141
#[allow(dead_code)]
146142
// Checks if u16 vector, when normalized, has a max value not greater than a u16 ratio max_limit.
147143
pub fn check_vec_max_limited(vec: &[u16], max_limit: u16) -> bool {
148-
let max_limit_fixed: I32F32 =
149-
I32F32::from_num(max_limit).saturating_div(I32F32::from_num(u16::MAX));
144+
let max_limit_fixed: I32F32 = I32F32::from_num(max_limit).safe_div(I32F32::from_num(u16::MAX));
150145
let mut vec_fixed: Vec<I32F32> = vec.iter().map(|e: &u16| I32F32::from_num(*e)).collect();
151146
inplace_normalize(&mut vec_fixed);
152147
let max_value: Option<&I32F32> = vec_fixed.iter().max();
@@ -219,7 +214,7 @@ pub fn sigmoid_safe(input: I32F32, rho: I32F32, kappa: I32F32) -> I32F32 {
219214
let exp_input: I32F32 = neg_rho.saturating_mul(offset); // -rho*(input-kappa)
220215
let exp_output: I32F32 = exp_safe(exp_input); // exp(-rho*(input-kappa))
221216
let denominator: I32F32 = exp_output.saturating_add(one); // 1 + exp(-rho*(input-kappa))
222-
let sigmoid_output: I32F32 = one.saturating_div(denominator); // 1 / (1 + exp(-rho*(input-kappa)))
217+
let sigmoid_output: I32F32 = one.safe_div(denominator); // 1 / (1 + exp(-rho*(input-kappa)))
223218
sigmoid_output
224219
}
225220

@@ -244,7 +239,7 @@ pub fn is_topk(vector: &[I32F32], k: usize) -> Vec<bool> {
244239
pub fn normalize(x: &[I32F32]) -> Vec<I32F32> {
245240
let x_sum: I32F32 = sum(x);
246241
if x_sum != I32F32::from_num(0.0_f32) {
247-
x.iter().map(|xi| xi.saturating_div(x_sum)).collect()
242+
x.iter().map(|xi| xi.safe_div(x_sum)).collect()
248243
} else {
249244
x.to_vec()
250245
}
@@ -258,7 +253,7 @@ pub fn inplace_normalize(x: &mut [I32F32]) {
258253
return;
259254
}
260255
x.iter_mut()
261-
.for_each(|value| *value = value.saturating_div(x_sum));
256+
.for_each(|value| *value = value.safe_div(x_sum));
262257
}
263258

264259
// Normalizes (sum to 1 except 0) the input vector directly in-place, using the sum arg.
@@ -268,7 +263,7 @@ pub fn inplace_normalize_using_sum(x: &mut [I32F32], x_sum: I32F32) {
268263
return;
269264
}
270265
x.iter_mut()
271-
.for_each(|value| *value = value.saturating_div(x_sum));
266+
.for_each(|value| *value = value.safe_div(x_sum));
272267
}
273268

274269
// Normalizes (sum to 1 except 0) the I64F64 input vector directly in-place.
@@ -279,7 +274,7 @@ pub fn inplace_normalize_64(x: &mut [I64F64]) {
279274
return;
280275
}
281276
x.iter_mut()
282-
.for_each(|value| *value = value.saturating_div(x_sum));
277+
.for_each(|value| *value = value.safe_div(x_sum));
283278
}
284279

285280
/// Normalizes (sum to 1 except 0) each row (dim=0) of a I64F64 matrix in-place.
@@ -289,7 +284,7 @@ pub fn inplace_row_normalize_64(x: &mut [Vec<I64F64>]) {
289284
let row_sum: I64F64 = row.iter().sum();
290285
if row_sum > I64F64::from_num(0.0_f64) {
291286
row.iter_mut()
292-
.for_each(|x_ij: &mut I64F64| *x_ij = x_ij.saturating_div(row_sum));
287+
.for_each(|x_ij: &mut I64F64| *x_ij = x_ij.safe_div(row_sum));
293288
}
294289
}
295290
}
@@ -302,7 +297,7 @@ pub fn vecdiv(x: &[I32F32], y: &[I32F32]) -> Vec<I32F32> {
302297
.zip(y)
303298
.map(|(x_i, y_i)| {
304299
if *y_i != 0 {
305-
x_i.saturating_div(*y_i)
300+
x_i.safe_div(*y_i)
306301
} else {
307302
I32F32::from_num(0)
308303
}
@@ -317,7 +312,7 @@ pub fn inplace_row_normalize(x: &mut [Vec<I32F32>]) {
317312
let row_sum: I32F32 = row.iter().sum();
318313
if row_sum > I32F32::from_num(0.0_f32) {
319314
row.iter_mut()
320-
.for_each(|x_ij: &mut I32F32| *x_ij = x_ij.saturating_div(row_sum));
315+
.for_each(|x_ij: &mut I32F32| *x_ij = x_ij.safe_div(row_sum));
321316
}
322317
}
323318
}
@@ -330,7 +325,7 @@ pub fn inplace_row_normalize_sparse(sparse_matrix: &mut [Vec<(u16, I32F32)>]) {
330325
if row_sum > I32F32::from_num(0.0) {
331326
sparse_row
332327
.iter_mut()
333-
.for_each(|(_j, value)| *value = value.saturating_div(row_sum));
328+
.for_each(|(_j, value)| *value = value.safe_div(row_sum));
334329
}
335330
}
336331
}
@@ -400,7 +395,7 @@ pub fn inplace_col_normalize_sparse(sparse_matrix: &mut [Vec<(u16, I32F32)>], co
400395
if col_sum[*j as usize] == I32F32::from_num(0.0_f32) {
401396
continue;
402397
}
403-
*value = value.saturating_div(col_sum[*j as usize]);
398+
*value = value.safe_div(col_sum[*j as usize]);
404399
}
405400
}
406401
}
@@ -428,7 +423,7 @@ pub fn inplace_col_normalize(x: &mut [Vec<I32F32>]) {
428423
.zip(&col_sums)
429424
.filter(|(_, col_sum)| **col_sum != I32F32::from_num(0_f32))
430425
.for_each(|(m_val, col_sum)| {
431-
*m_val = m_val.saturating_div(*col_sum);
426+
*m_val = m_val.safe_div(*col_sum);
432427
});
433428
});
434429
}
@@ -449,7 +444,7 @@ pub fn inplace_col_max_upscale_sparse(sparse_matrix: &mut [Vec<(u16, I32F32)>],
449444
if col_max[*j as usize] == I32F32::from_num(0.0_f32) {
450445
continue;
451446
}
452-
*value = value.saturating_div(col_max[*j as usize]);
447+
*value = value.safe_div(col_max[*j as usize]);
453448
}
454449
}
455450
}
@@ -477,7 +472,7 @@ pub fn inplace_col_max_upscale(x: &mut [Vec<I32F32>]) {
477472
.zip(&col_maxes)
478473
.filter(|(_, col_max)| **col_max != I32F32::from_num(0))
479474
.for_each(|(m_val, col_max)| {
480-
*m_val = m_val.saturating_div(*col_max);
475+
*m_val = m_val.safe_div(*col_max);
481476
});
482477
});
483478
}
@@ -898,7 +893,7 @@ pub fn weighted_median(
898893
return score[partition_idx[0]];
899894
}
900895
assert!(stake.len() == score.len());
901-
let mid_idx: usize = n.saturating_div(2);
896+
let mid_idx: usize = n.safe_div(2);
902897
let pivot: I32F32 = score[partition_idx[mid_idx]];
903898
let mut lo_stake: I32F32 = I32F32::from_num(0);
904899
let mut hi_stake: I32F32 = I32F32::from_num(0);
@@ -1411,3 +1406,48 @@ pub fn safe_ln(value: I32F32) -> I32F32 {
14111406
pub fn safe_exp(value: I32F32) -> I32F32 {
14121407
exp(value).unwrap_or(I32F32::from_num(0.0))
14131408
}
1409+
1410+
/// Safe division trait
1411+
pub trait SafeDiv {
1412+
/// Safe division that returns supplied default value for division by zero
1413+
fn safe_div_or(self, rhs: Self, def: Self) -> Self;
1414+
/// Safe division that returns default value for division by zero
1415+
fn safe_div(self, rhs: Self) -> Self;
1416+
}
1417+
1418+
/// Implementation of safe division trait for primitive types
1419+
macro_rules! impl_safe_div_for_primitive {
1420+
($($t:ty),*) => {
1421+
$(
1422+
impl SafeDiv for $t {
1423+
fn safe_div_or(self, rhs: Self, def: Self) -> Self {
1424+
self.checked_div(rhs).unwrap_or(def)
1425+
}
1426+
1427+
fn safe_div(self, rhs: Self) -> Self {
1428+
self.checked_div(rhs).unwrap_or_default()
1429+
}
1430+
}
1431+
)*
1432+
};
1433+
}
1434+
impl_safe_div_for_primitive!(u8, u16, u32, u64, i8, i16, i32, i64, usize);
1435+
1436+
/// Implementation of safe division trait for substrate fixed types
1437+
macro_rules! impl_safe_div_for_fixed {
1438+
($($t:ty),*) => {
1439+
$(
1440+
impl SafeDiv for $t {
1441+
fn safe_div_or(self, rhs: Self, def: Self) -> Self {
1442+
self.checked_div(rhs).unwrap_or(def)
1443+
}
1444+
1445+
fn safe_div(self, rhs: Self) -> Self {
1446+
self.checked_div(rhs).unwrap_or_default()
1447+
}
1448+
}
1449+
)*
1450+
};
1451+
}
1452+
1453+
impl_safe_div_for_fixed!(I96F32, I32F32, I64F64, I110F18, U64F64);

0 commit comments

Comments
 (0)