Skip to content

Commit 5ca8fa5

Browse files
authored
Merge pull request #1347 from opentensor/forbid-saturating-math-lint
Ban saturating arithmetic in tests
2 parents b7f2640 + e750ca4 commit 5ca8fa5

File tree

4 files changed

+125
-5
lines changed

4 files changed

+125
-5
lines changed

build.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ fn main() {
3131

3232
// Parse each rust file with syn and run the linting suite on it in parallel
3333
rust_files.par_iter().for_each_with(tx.clone(), |tx, file| {
34+
let is_test = file.display().to_string().contains("test");
3435
let Ok(content) = fs::read_to_string(file) else {
3536
return;
3637
};
@@ -63,6 +64,10 @@ fn main() {
6364
track_lint(ForbidKeysRemoveCall::lint(&parsed_file));
6465
track_lint(RequireFreezeStruct::lint(&parsed_file));
6566
track_lint(RequireExplicitPalletIndex::lint(&parsed_file));
67+
68+
if is_test {
69+
track_lint(ForbidSaturatingMath::lint(&parsed_file));
70+
}
6671
});
6772

6873
// Collect and print all errors after the parallel processing is done

pallets/subtensor/src/tests/epoch.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2285,19 +2285,19 @@ fn test_compute_alpha_values() {
22852285
// exp_val = exp(0.0 - 1.0 * 0.1) = exp(-0.1)
22862286
// alpha[0] = 1 / (1 + exp(-0.1)) ~ 0.9048374180359595
22872287
let exp_val_0 = I32F32::from_num(0.9048374180359595);
2288-
let expected_alpha_0 = I32F32::from_num(1.0) / I32F32::from_num(1.0).saturating_add(exp_val_0);
2288+
let expected_alpha_0 = I32F32::from_num(1.0) / (I32F32::from_num(1.0) + exp_val_0);
22892289

22902290
// For consensus[1] = 0.5:
22912291
// exp_val = exp(0.0 - 1.0 * 0.5) = exp(-0.5)
22922292
// alpha[1] = 1 / (1 + exp(-0.5)) ~ 0.6065306597126334
22932293
let exp_val_1 = I32F32::from_num(0.6065306597126334);
2294-
let expected_alpha_1 = I32F32::from_num(1.0) / I32F32::from_num(1.0).saturating_add(exp_val_1);
2294+
let expected_alpha_1 = I32F32::from_num(1.0) / (I32F32::from_num(1.0) + exp_val_1);
22952295

22962296
// For consensus[2] = 0.9:
22972297
// exp_val = exp(0.0 - 1.0 * 0.9) = exp(-0.9)
22982298
// alpha[2] = 1 / (1 + exp(-0.9)) ~ 0.4065696597405991
22992299
let exp_val_2 = I32F32::from_num(0.4065696597405991);
2300-
let expected_alpha_2 = I32F32::from_num(1.0) / I32F32::from_num(1.0).saturating_add(exp_val_2);
2300+
let expected_alpha_2 = I32F32::from_num(1.0) / (I32F32::from_num(1.0) + exp_val_2);
23012301

23022302
// Define an epsilon for approximate equality checks.
23032303
let epsilon = I32F32::from_num(1e-6);
@@ -2329,13 +2329,13 @@ fn test_compute_alpha_values_256_miners() {
23292329

23302330
for (i, &c) in consensus.iter().enumerate() {
23312331
// Use saturating subtraction and multiplication
2332-
let exponent = b.saturating_sub(a.saturating_mul(c));
2332+
let exponent = b - (a * c);
23332333

23342334
// Use safe_exp instead of exp
23352335
let exp_val = safe_exp(exponent);
23362336

23372337
// Use saturating addition and division
2338-
let expected_alpha = I32F32::from_num(1.0) / I32F32::from_num(1.0).saturating_add(exp_val);
2338+
let expected_alpha = I32F32::from_num(1.0) / (I32F32::from_num(1.0) + exp_val);
23392339

23402340
// Assert that the computed alpha values match the expected values within the epsilon.
23412341
assert_approx_eq(alpha[i], expected_alpha, epsilon);
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
use super::*;
2+
use syn::{Expr, ExprCall, ExprMethodCall, ExprPath, File, Path, spanned::Spanned, visit::Visit};
3+
4+
pub struct ForbidSaturatingMath;
5+
6+
impl Lint for ForbidSaturatingMath {
7+
fn lint(source: &File) -> Result {
8+
let mut visitor = SaturatingMathBanVisitor::default();
9+
visitor.visit_file(source);
10+
11+
if visitor.errors.is_empty() {
12+
Ok(())
13+
} else {
14+
Err(visitor.errors)
15+
}
16+
}
17+
}
18+
19+
#[derive(Default)]
20+
struct SaturatingMathBanVisitor {
21+
errors: Vec<syn::Error>,
22+
}
23+
24+
impl<'ast> Visit<'ast> for SaturatingMathBanVisitor {
25+
fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
26+
let ExprMethodCall { method, .. } = node;
27+
28+
if method.to_string().starts_with("saturating_") {
29+
let msg = "Safe math is banned to encourage tests to panic";
30+
self.errors.push(syn::Error::new(method.span(), msg));
31+
}
32+
}
33+
34+
fn visit_expr_call(&mut self, node: &'ast ExprCall) {
35+
let ExprCall { func, .. } = node;
36+
37+
if is_saturating_math_call(func) {
38+
let msg = "Safe math is banned to encourage tests to panic";
39+
self.errors.push(syn::Error::new(node.func.span(), msg));
40+
}
41+
}
42+
}
43+
44+
fn is_saturating_math_call(func: &Expr) -> bool {
45+
let Expr::Path(ExprPath {
46+
path: Path { segments: path, .. },
47+
..
48+
}) = func
49+
else {
50+
return false;
51+
};
52+
53+
path.last()
54+
.is_some_and(|seg| seg.ident.to_string().starts_with("saturating_"))
55+
}
56+
57+
#[cfg(test)]
58+
mod tests {
59+
use super::*;
60+
use quote::quote;
61+
62+
fn lint(input: proc_macro2::TokenStream) -> Result {
63+
let mut visitor = SaturatingMathBanVisitor::default();
64+
let expr: syn::Expr = syn::parse2(input).expect("should be a valid expression");
65+
66+
match &expr {
67+
syn::Expr::MethodCall(call) => visitor.visit_expr_method_call(call),
68+
syn::Expr::Call(call) => visitor.visit_expr_call(call),
69+
_ => panic!("should be a valid method call or function call"),
70+
}
71+
72+
if visitor.errors.is_empty() {
73+
Ok(())
74+
} else {
75+
Err(visitor.errors)
76+
}
77+
}
78+
79+
#[test]
80+
fn test_saturating_forbidden() {
81+
let input = quote! { stake.saturating_add(alpha) };
82+
assert!(lint(input).is_err());
83+
let input = quote! { alpha_price.saturating_mul(float_alpha_block_emission) };
84+
assert!(lint(input).is_err());
85+
let input = quote! { alpha_out_i.saturating_sub(root_alpha) };
86+
assert!(lint(input).is_err());
87+
}
88+
89+
#[test]
90+
fn test_saturating_ufcs_forbidden() {
91+
let input = quote! { SaturatingAdd::saturating_add(stake, alpha) };
92+
assert!(lint(input).is_err());
93+
let input = quote! { core::num::SaturatingAdd::saturating_add(stake, alpha) };
94+
assert!(lint(input).is_err());
95+
let input =
96+
quote! { SaturatingMul::saturating_mul(alpha_price, float_alpha_block_emission) };
97+
assert!(lint(input).is_err());
98+
let input = quote! { core::num::SaturatingMul::saturating_mul(alpha_price, float_alpha_block_emission) };
99+
assert!(lint(input).is_err());
100+
let input = quote! { SaturatingSub::saturating_sub(alpha_out_i, root_alpha) };
101+
assert!(lint(input).is_err());
102+
let input = quote! { core::num::SaturatingSub::saturating_sub(alpha_out_i, root_alpha) };
103+
assert!(lint(input).is_err());
104+
}
105+
106+
#[test]
107+
fn test_saturating_to_from_num_forbidden() {
108+
let input = quote! { I96F32::saturating_from_num(u64::MAX) };
109+
assert!(lint(input).is_err());
110+
let input = quote! { remaining_emission.saturating_to_num::<u64>() };
111+
assert!(lint(input).is_err());
112+
}
113+
}

support/linting/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ pub use lint::*;
33

44
mod forbid_as_primitive;
55
mod forbid_keys_remove;
6+
mod forbid_saturating_math;
67
mod pallet_index;
78
mod require_freeze_struct;
89

910
pub use forbid_as_primitive::ForbidAsPrimitiveConversion;
1011
pub use forbid_keys_remove::ForbidKeysRemoveCall;
12+
pub use forbid_saturating_math::ForbidSaturatingMath;
1113
pub use pallet_index::RequireExplicitPalletIndex;
1214
pub use require_freeze_struct::RequireFreezeStruct;

0 commit comments

Comments
 (0)