@@ -8,7 +8,7 @@ use sp_std::cmp::Ordering;
8
8
9
9
use sp_std:: vec;
10
10
use substrate_fixed:: transcendental:: { exp, ln} ;
11
- use substrate_fixed:: types:: { I32F32 , I64F64 } ;
11
+ use substrate_fixed:: types:: { I110F18 , I32F32 , I64F64 , I96F32 , U64F64 } ;
12
12
13
13
// TODO: figure out what cfg gate this needs to not be a warning in rustc
14
14
#[ allow( unused) ]
@@ -51,7 +51,7 @@ pub fn u16_to_fixed(x: u16) -> I32F32 {
51
51
52
52
#[ allow( dead_code) ]
53
53
pub fn u16_proportion_to_fixed ( x : u16 ) -> I32F32 {
54
- I32F32 :: from_num ( x) . saturating_div ( I32F32 :: from_num ( u16:: MAX ) )
54
+ I32F32 :: from_num ( x) . safe_div ( I32F32 :: from_num ( u16:: MAX ) )
55
55
}
56
56
57
57
#[ allow( dead_code) ]
@@ -107,7 +107,7 @@ pub fn vec_max_upscale_to_u16(vec: &[I32F32]) -> Vec<u16> {
107
107
return vec
108
108
. iter ( )
109
109
. map ( |e : & I32F32 | {
110
- e. saturating_mul ( u16_max. saturating_div ( * val) )
110
+ e. saturating_mul ( u16_max. safe_div ( * val) )
111
111
. round ( )
112
112
. to_num :: < u16 > ( )
113
113
} )
@@ -116,7 +116,7 @@ pub fn vec_max_upscale_to_u16(vec: &[I32F32]) -> Vec<u16> {
116
116
vec. iter ( )
117
117
. map ( |e : & I32F32 | {
118
118
e. saturating_mul ( u16_max)
119
- . saturating_div ( * val)
119
+ . safe_div ( * val)
120
120
. round ( )
121
121
. to_num :: < u16 > ( )
122
122
} )
@@ -125,11 +125,7 @@ pub fn vec_max_upscale_to_u16(vec: &[I32F32]) -> Vec<u16> {
125
125
None => {
126
126
let sum: I32F32 = vec. iter ( ) . sum ( ) ;
127
127
vec. iter ( )
128
- . map ( |e : & I32F32 | {
129
- e. saturating_mul ( u16_max)
130
- . saturating_div ( sum)
131
- . to_num :: < u16 > ( )
132
- } )
128
+ . map ( |e : & I32F32 | e. saturating_mul ( u16_max) . safe_div ( sum) . to_num :: < u16 > ( ) )
133
129
. collect ( )
134
130
}
135
131
}
@@ -145,8 +141,7 @@ pub fn vec_u16_max_upscale_to_u16(vec: &[u16]) -> Vec<u16> {
145
141
#[ allow( dead_code) ]
146
142
// Checks if u16 vector, when normalized, has a max value not greater than a u16 ratio max_limit.
147
143
pub fn check_vec_max_limited ( vec : & [ u16 ] , max_limit : u16 ) -> bool {
148
- let max_limit_fixed: I32F32 =
149
- I32F32 :: from_num ( max_limit) . saturating_div ( I32F32 :: from_num ( u16:: MAX ) ) ;
144
+ let max_limit_fixed: I32F32 = I32F32 :: from_num ( max_limit) . safe_div ( I32F32 :: from_num ( u16:: MAX ) ) ;
150
145
let mut vec_fixed: Vec < I32F32 > = vec. iter ( ) . map ( |e : & u16 | I32F32 :: from_num ( * e) ) . collect ( ) ;
151
146
inplace_normalize ( & mut vec_fixed) ;
152
147
let max_value: Option < & I32F32 > = vec_fixed. iter ( ) . max ( ) ;
@@ -219,7 +214,7 @@ pub fn sigmoid_safe(input: I32F32, rho: I32F32, kappa: I32F32) -> I32F32 {
219
214
let exp_input: I32F32 = neg_rho. saturating_mul ( offset) ; // -rho*(input-kappa)
220
215
let exp_output: I32F32 = exp_safe ( exp_input) ; // exp(-rho*(input-kappa))
221
216
let denominator: I32F32 = exp_output. saturating_add ( one) ; // 1 + exp(-rho*(input-kappa))
222
- let sigmoid_output: I32F32 = one. saturating_div ( denominator) ; // 1 / (1 + exp(-rho*(input-kappa)))
217
+ let sigmoid_output: I32F32 = one. safe_div ( denominator) ; // 1 / (1 + exp(-rho*(input-kappa)))
223
218
sigmoid_output
224
219
}
225
220
@@ -244,7 +239,7 @@ pub fn is_topk(vector: &[I32F32], k: usize) -> Vec<bool> {
244
239
pub fn normalize ( x : & [ I32F32 ] ) -> Vec < I32F32 > {
245
240
let x_sum: I32F32 = sum ( x) ;
246
241
if x_sum != I32F32 :: from_num ( 0.0_f32 ) {
247
- x. iter ( ) . map ( |xi| xi. saturating_div ( x_sum) ) . collect ( )
242
+ x. iter ( ) . map ( |xi| xi. safe_div ( x_sum) ) . collect ( )
248
243
} else {
249
244
x. to_vec ( )
250
245
}
@@ -258,7 +253,7 @@ pub fn inplace_normalize(x: &mut [I32F32]) {
258
253
return ;
259
254
}
260
255
x. iter_mut ( )
261
- . for_each ( |value| * value = value. saturating_div ( x_sum) ) ;
256
+ . for_each ( |value| * value = value. safe_div ( x_sum) ) ;
262
257
}
263
258
264
259
// Normalizes (sum to 1 except 0) the input vector directly in-place, using the sum arg.
@@ -268,7 +263,7 @@ pub fn inplace_normalize_using_sum(x: &mut [I32F32], x_sum: I32F32) {
268
263
return ;
269
264
}
270
265
x. iter_mut ( )
271
- . for_each ( |value| * value = value. saturating_div ( x_sum) ) ;
266
+ . for_each ( |value| * value = value. safe_div ( x_sum) ) ;
272
267
}
273
268
274
269
// Normalizes (sum to 1 except 0) the I64F64 input vector directly in-place.
@@ -279,7 +274,7 @@ pub fn inplace_normalize_64(x: &mut [I64F64]) {
279
274
return ;
280
275
}
281
276
x. iter_mut ( )
282
- . for_each ( |value| * value = value. saturating_div ( x_sum) ) ;
277
+ . for_each ( |value| * value = value. safe_div ( x_sum) ) ;
283
278
}
284
279
285
280
/// Normalizes (sum to 1 except 0) each row (dim=0) of a I64F64 matrix in-place.
@@ -289,7 +284,7 @@ pub fn inplace_row_normalize_64(x: &mut [Vec<I64F64>]) {
289
284
let row_sum: I64F64 = row. iter ( ) . sum ( ) ;
290
285
if row_sum > I64F64 :: from_num ( 0.0_f64 ) {
291
286
row. iter_mut ( )
292
- . for_each ( |x_ij : & mut I64F64 | * x_ij = x_ij. saturating_div ( row_sum) ) ;
287
+ . for_each ( |x_ij : & mut I64F64 | * x_ij = x_ij. safe_div ( row_sum) ) ;
293
288
}
294
289
}
295
290
}
@@ -302,7 +297,7 @@ pub fn vecdiv(x: &[I32F32], y: &[I32F32]) -> Vec<I32F32> {
302
297
. zip ( y)
303
298
. map ( |( x_i, y_i) | {
304
299
if * y_i != 0 {
305
- x_i. saturating_div ( * y_i)
300
+ x_i. safe_div ( * y_i)
306
301
} else {
307
302
I32F32 :: from_num ( 0 )
308
303
}
@@ -317,7 +312,7 @@ pub fn inplace_row_normalize(x: &mut [Vec<I32F32>]) {
317
312
let row_sum: I32F32 = row. iter ( ) . sum ( ) ;
318
313
if row_sum > I32F32 :: from_num ( 0.0_f32 ) {
319
314
row. iter_mut ( )
320
- . for_each ( |x_ij : & mut I32F32 | * x_ij = x_ij. saturating_div ( row_sum) ) ;
315
+ . for_each ( |x_ij : & mut I32F32 | * x_ij = x_ij. safe_div ( row_sum) ) ;
321
316
}
322
317
}
323
318
}
@@ -330,7 +325,7 @@ pub fn inplace_row_normalize_sparse(sparse_matrix: &mut [Vec<(u16, I32F32)>]) {
330
325
if row_sum > I32F32 :: from_num ( 0.0 ) {
331
326
sparse_row
332
327
. iter_mut ( )
333
- . for_each ( |( _j, value) | * value = value. saturating_div ( row_sum) ) ;
328
+ . for_each ( |( _j, value) | * value = value. safe_div ( row_sum) ) ;
334
329
}
335
330
}
336
331
}
@@ -400,7 +395,7 @@ pub fn inplace_col_normalize_sparse(sparse_matrix: &mut [Vec<(u16, I32F32)>], co
400
395
if col_sum[ * j as usize ] == I32F32 :: from_num ( 0.0_f32 ) {
401
396
continue ;
402
397
}
403
- * value = value. saturating_div ( col_sum[ * j as usize ] ) ;
398
+ * value = value. safe_div ( col_sum[ * j as usize ] ) ;
404
399
}
405
400
}
406
401
}
@@ -428,7 +423,7 @@ pub fn inplace_col_normalize(x: &mut [Vec<I32F32>]) {
428
423
. zip ( & col_sums)
429
424
. filter ( |( _, col_sum) | * * col_sum != I32F32 :: from_num ( 0_f32 ) )
430
425
. for_each ( |( m_val, col_sum) | {
431
- * m_val = m_val. saturating_div ( * col_sum) ;
426
+ * m_val = m_val. safe_div ( * col_sum) ;
432
427
} ) ;
433
428
} ) ;
434
429
}
@@ -449,7 +444,7 @@ pub fn inplace_col_max_upscale_sparse(sparse_matrix: &mut [Vec<(u16, I32F32)>],
449
444
if col_max[ * j as usize ] == I32F32 :: from_num ( 0.0_f32 ) {
450
445
continue ;
451
446
}
452
- * value = value. saturating_div ( col_max[ * j as usize ] ) ;
447
+ * value = value. safe_div ( col_max[ * j as usize ] ) ;
453
448
}
454
449
}
455
450
}
@@ -477,7 +472,7 @@ pub fn inplace_col_max_upscale(x: &mut [Vec<I32F32>]) {
477
472
. zip ( & col_maxes)
478
473
. filter ( |( _, col_max) | * * col_max != I32F32 :: from_num ( 0 ) )
479
474
. for_each ( |( m_val, col_max) | {
480
- * m_val = m_val. saturating_div ( * col_max) ;
475
+ * m_val = m_val. safe_div ( * col_max) ;
481
476
} ) ;
482
477
} ) ;
483
478
}
@@ -898,7 +893,7 @@ pub fn weighted_median(
898
893
return score[ partition_idx[ 0 ] ] ;
899
894
}
900
895
assert ! ( stake. len( ) == score. len( ) ) ;
901
- let mid_idx: usize = n. saturating_div ( 2 ) ;
896
+ let mid_idx: usize = n. safe_div ( 2 ) ;
902
897
let pivot: I32F32 = score[ partition_idx[ mid_idx] ] ;
903
898
let mut lo_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
904
899
let mut hi_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
@@ -1411,3 +1406,48 @@ pub fn safe_ln(value: I32F32) -> I32F32 {
1411
1406
pub fn safe_exp ( value : I32F32 ) -> I32F32 {
1412
1407
exp ( value) . unwrap_or ( I32F32 :: from_num ( 0.0 ) )
1413
1408
}
1409
+
1410
+ /// Safe division trait
1411
+ pub trait SafeDiv {
1412
+ /// Safe division that returns supplied default value for division by zero
1413
+ fn safe_div_or ( self , rhs : Self , def : Self ) -> Self ;
1414
+ /// Safe division that returns default value for division by zero
1415
+ fn safe_div ( self , rhs : Self ) -> Self ;
1416
+ }
1417
+
1418
+ /// Implementation of safe division trait for primitive types
1419
+ macro_rules! impl_safe_div_for_primitive {
1420
+ ( $( $t: ty) ,* ) => {
1421
+ $(
1422
+ impl SafeDiv for $t {
1423
+ fn safe_div_or( self , rhs: Self , def: Self ) -> Self {
1424
+ self . checked_div( rhs) . unwrap_or( def)
1425
+ }
1426
+
1427
+ fn safe_div( self , rhs: Self ) -> Self {
1428
+ self . checked_div( rhs) . unwrap_or_default( )
1429
+ }
1430
+ }
1431
+ ) *
1432
+ } ;
1433
+ }
1434
+ impl_safe_div_for_primitive ! ( u8 , u16 , u32 , u64 , i8 , i16 , i32 , i64 , usize ) ;
1435
+
1436
+ /// Implementation of safe division trait for substrate fixed types
1437
+ macro_rules! impl_safe_div_for_fixed {
1438
+ ( $( $t: ty) ,* ) => {
1439
+ $(
1440
+ impl SafeDiv for $t {
1441
+ fn safe_div_or( self , rhs: Self , def: Self ) -> Self {
1442
+ self . checked_div( rhs) . unwrap_or( def)
1443
+ }
1444
+
1445
+ fn safe_div( self , rhs: Self ) -> Self {
1446
+ self . checked_div( rhs) . unwrap_or_default( )
1447
+ }
1448
+ }
1449
+ ) *
1450
+ } ;
1451
+ }
1452
+
1453
+ impl_safe_div_for_fixed ! ( I96F32 , I32F32 , I64F64 , I110F18 , U64F64 ) ;
0 commit comments