1
1
use super :: * ;
2
2
use syn:: {
3
3
punctuated:: Punctuated , spanned:: Spanned , token:: Comma , visit:: Visit , Expr , ExprCall , ExprPath ,
4
- File ,
4
+ File , Path ,
5
5
} ;
6
6
7
7
pub struct ForbidKeysRemoveCall ;
@@ -26,25 +26,30 @@ struct KeysRemoveVisitor {
26
26
27
27
impl < ' ast > Visit < ' ast > for KeysRemoveVisitor {
28
28
fn visit_expr_call ( & mut self , node : & ' ast syn:: ExprCall ) {
29
- let ExprCall { func, args, .. } = node;
30
- if is_keys_remove_call ( func, args) {
29
+ let ExprCall {
30
+ func, args, attrs, ..
31
+ } = node;
32
+
33
+ if is_keys_remove_call ( func, args) && !is_allowed ( attrs) {
31
34
let msg = "Keys::<T>::remove()` is banned to prevent accidentally breaking \
32
- the neuron sequence. If you need to replace neuron , try `SubtensorModule::replace_neuron()`";
35
+ the neuron sequence. If you need to replace neurons , try `SubtensorModule::replace_neuron()`";
33
36
self . errors . push ( syn:: Error :: new ( node. func . span ( ) , msg) ) ;
34
37
}
35
38
}
36
39
}
37
40
38
41
fn is_keys_remove_call ( func : & Expr , args : & Punctuated < Expr , Comma > ) -> bool {
39
- let Expr :: Path ( ExprPath { path, .. } ) = func else {
42
+ let Expr :: Path ( ExprPath {
43
+ path : Path { segments : func, .. } ,
44
+ ..
45
+ } ) = func
46
+ else {
40
47
return false ;
41
48
} ;
42
- let func = & path. segments ;
43
- if func. len ( ) != 2 || args. len ( ) != 2 {
44
- return false ;
45
- }
46
49
47
- func[ 0 ] . ident == "Keys"
50
+ func. len ( ) == 2
51
+ && args. len ( ) == 2
52
+ && func[ 0 ] . ident == "Keys"
48
53
&& !func[ 0 ] . arguments . is_none ( )
49
54
&& func[ 1 ] . ident == "remove"
50
55
&& func[ 1 ] . arguments . is_none ( )
@@ -91,4 +96,23 @@ mod tests {
91
96
let input = r#"ChildKeys::<T>::remove(netuid, uid_to_replace)"# ;
92
97
assert ! ( lint( input) . is_ok( ) ) ;
93
98
}
99
+
100
+ #[ test]
101
+ fn test_keys_remove_allowed ( ) {
102
+ let input = r#"
103
+ #[allow(unknown_lints)]
104
+ Keys::<T>::remove(netuid, uid_to_replace)
105
+ "# ;
106
+ assert ! ( lint( input) . is_ok( ) ) ;
107
+ let input = r#"
108
+ #[allow(unknown_lints)]
109
+ Keys::<U>::remove(netuid, uid_to_replace)
110
+ "# ;
111
+ assert ! ( lint( input) . is_ok( ) ) ;
112
+ let input = r#"
113
+ #[allow(unknown_lints)]
114
+ Keys::<U>::remove(1, "2".parse().unwrap(),)
115
+ "# ;
116
+ assert ! ( lint( input) . is_ok( ) ) ;
117
+ }
94
118
}
0 commit comments