@@ -115,41 +115,101 @@ impl<'tcx> LateLintPass<'tcx> for DuplicateMutableAccounts {
115
115
}
116
116
}
117
117
}
118
+ } else {
119
+ // perform alternate constraint check, e.g., check fn bodies, then check key checks
120
+ self. check_fn ( )
121
+ }
122
+
123
+ // TODO: how to enforce that this is only called when necessary?
124
+ fn check_fn (
125
+ & mut self ,
126
+ cx : & LateContext < ' tcx > ,
127
+ _: FnKind < ' tcx > ,
128
+ _: & ' tcx FnDecl < ' tcx > ,
129
+ body : & ' tcx Body < ' tcx > ,
130
+ span : Span ,
131
+ _: HirId ,
132
+ ) {
133
+ if !span. from_expansion ( ) {
134
+ let accounts = get_referenced_accounts ( cx, body) ;
135
+
136
+ accounts. values ( ) . for_each ( |exprs| {
137
+ // TODO: figure out handling of >2 accounts
138
+ match exprs. len ( ) {
139
+ 2 => {
140
+ let first = exprs[ 0 ] ;
141
+ let second = exprs[ 1 ] ;
142
+ if !contains_key_call ( cx, body, first) {
143
+ span_lint_and_help (
144
+ cx,
145
+ DUP_MUTABLE_ACCOUNTS_2 ,
146
+ first. span ,
147
+ "this expression does not have a key check but has the same account type as another expression" ,
148
+ Some ( second. span ) ,
149
+ "add a key check to make sure the accounts have different keys, e.g., x.key() != y.key()" ,
150
+ ) ;
151
+ }
152
+ if !contains_key_call ( cx, body, second) {
153
+ span_lint_and_help (
154
+ cx,
155
+ DUP_MUTABLE_ACCOUNTS_2 ,
156
+ second. span ,
157
+ "this expression does not have a key check but has the same account type as another expression" ,
158
+ Some ( first. span ) ,
159
+ "add a key check to make sure the accounts have different keys, e.g., x.key() != y.key()" ,
160
+ ) ;
161
+ }
162
+ } ,
163
+ n if n > 2 => {
164
+ span_lint_and_note (
165
+ cx,
166
+ DUP_MUTABLE_ACCOUNTS_2 ,
167
+ exprs[ 0 ] . span ,
168
+ & format ! ( "the following expression has the same account type as {} other accounts" , exprs. len( ) ) ,
169
+ None ,
170
+ "might not check that each account has a unique key"
171
+ )
172
+ } ,
173
+ _ => { }
174
+ }
175
+ } ) ;
176
+ }
118
177
}
119
178
}
120
179
}
121
180
122
- /// Returns the `DefId` of the anchor account type, ie, `T` in `Account<'info, T>`.
123
- /// Returns `None` if the type of `field` is not an anchor account.
124
- fn get_anchor_account_type_def_id ( field : & FieldDef ) -> Option < DefId > {
125
- if_chain ! {
126
- if let TyKind :: Path ( qpath) = & field. ty. kind;
127
- if let QPath :: Resolved ( _, path) = qpath;
128
- if !path. segments. is_empty( ) ;
129
- if let Some ( generic_args) = path. segments[ 0 ] . args;
130
- if generic_args. args. len( ) == ANCHOR_ACCOUNT_GENERIC_ARG_COUNT ;
131
- if let GenericArg :: Type ( hir_ty) = & generic_args. args[ 1 ] ;
132
- then {
133
- get_def_id( hir_ty)
134
- } else {
135
- None
181
+ mod anchor_constraint_check {
182
+ /// Returns the `DefId` of the anchor account type, ie, `T` in `Account<'info, T>`.
183
+ /// Returns `None` if the type of `field` is not an anchor account.
184
+ fn get_anchor_account_type_def_id ( field : & FieldDef ) -> Option < DefId > {
185
+ if_chain ! {
186
+ if let TyKind :: Path ( qpath) = & field. ty. kind;
187
+ if let QPath :: Resolved ( _, path) = qpath;
188
+ if !path. segments. is_empty( ) ;
189
+ if let Some ( generic_args) = path. segments[ 0 ] . args;
190
+ if generic_args. args. len( ) == ANCHOR_ACCOUNT_GENERIC_ARG_COUNT ;
191
+ if let GenericArg :: Type ( hir_ty) = & generic_args. args[ 1 ] ;
192
+ then {
193
+ get_def_id( hir_ty)
194
+ } else {
195
+ None
196
+ }
136
197
}
137
198
}
138
- }
139
199
140
- /// Returns the `DefId` of `ty`, an hir type. Returns `None` if cannot resolve type.
141
- fn get_def_id ( ty : & rustc_hir:: Ty ) -> Option < DefId > {
142
- if_chain ! {
143
- if let TyKind :: Path ( qpath) = & ty. kind;
144
- if let QPath :: Resolved ( _, path) = qpath;
145
- if let Res :: Def ( _, def_id) = path. res;
146
- then {
147
- Some ( def_id)
148
- } else {
149
- None
200
+ /// Returns the `DefId` of `ty`, an hir type. Returns `None` if cannot resolve type.
201
+ fn get_def_id ( ty : & rustc_hir:: Ty ) -> Option < DefId > {
202
+ if_chain ! {
203
+ if let TyKind :: Path ( qpath) = & ty. kind;
204
+ if let QPath :: Resolved ( _, path) = qpath;
205
+ if let Res :: Def ( _, def_id) = path. res;
206
+ then {
207
+ Some ( def_id)
208
+ } else {
209
+ None
210
+ }
150
211
}
151
212
}
152
- }
153
213
154
214
/// Returns a `TokenStream` of form: `a`.key() != `b`.key().
155
215
fn create_key_check_constraint_tokenstream ( a : Symbol , b : Symbol ) -> TokenStream {
@@ -175,17 +235,27 @@ fn create_key_check_constraint_tokenstream(a: Symbol, b: Symbol) -> TokenStream
175
235
) ) ,
176
236
] ;
177
237
178
- TokenStream :: new ( constraint)
179
- }
238
+ TokenStream :: new ( constraint)
239
+ }
240
+
241
+ /// Returns a `TokenTree::Token` which has `TokenKind::Ident`, with the string set to `s`.
242
+ fn create_token_from_ident ( s : & str ) -> TokenTree {
243
+ let ident = Ident :: from_str ( s) ;
244
+ TokenTree :: Token ( Token :: from_ast_ident ( ident) )
245
+ }
246
+
247
+ #[ derive( Debug , Default ) ]
248
+ pub struct Streams ( Vec < TokenStream > ) ;
180
249
181
- /// Returns a `TokenTree::Token` which has `TokenKind::Ident`, with the string set to `s`.
182
- fn create_token_from_ident ( s : & str ) -> TokenTree {
183
- let ident = Ident :: from_str ( s) ;
184
- TokenTree :: Token ( Token :: from_ast_ident ( ident) )
250
+ impl Streams {
251
+ /// Returns true if `self` contains `other`, by comparing if there is an
252
+ /// identical `TokenStream` in `self` regardless of span.
253
+ fn contains ( & self , other : & TokenStream ) -> bool {
254
+ self . 0 . iter ( ) . any ( |stream| stream. eq_unspanned ( other) )
255
+ }
256
+ }
185
257
}
186
258
187
- #[ derive( Debug , Default ) ]
188
- pub struct Streams ( Vec < TokenStream > ) ;
189
259
190
260
impl Streams {
191
261
/// Returns true if `self` has a TokenStream that `other` is a substream of
@@ -222,6 +292,98 @@ impl Streams {
222
292
}
223
293
}
224
294
295
+ mod alternate_constraint_check {
296
+ struct AccountUses < ' cx , ' tcx > {
297
+ cx : & ' cx LateContext < ' tcx > ,
298
+ uses : HashMap < DefId , Vec < & ' tcx Expr < ' tcx > > > ,
299
+ }
300
+
301
+ fn get_referenced_accounts < ' tcx > (
302
+ cx : & LateContext < ' tcx > ,
303
+ body : & ' tcx Body < ' tcx > ,
304
+ ) -> HashMap < DefId , Vec < & ' tcx Expr < ' tcx > > > {
305
+ let mut accounts = AccountUses {
306
+ cx,
307
+ uses : HashMap :: new ( ) ,
308
+ } ;
309
+
310
+ accounts. visit_expr ( & body. value ) ;
311
+ accounts. uses
312
+ }
313
+
314
+ impl < ' cx , ' tcx > Visitor < ' tcx > for AccountUses < ' cx , ' tcx > {
315
+ fn visit_expr ( & mut self , expr : & ' tcx Expr < ' tcx > ) {
316
+ if_chain ! {
317
+ // get mutable reference expressions
318
+ if let ExprKind :: AddrOf ( _, mutability, mut_expr) = expr. kind;
319
+ if let Mutability :: Mut = mutability;
320
+ // check type of expr == Account<'info, T>
321
+ let middle_ty = self . cx. typeck_results( ) . expr_ty( mut_expr) ;
322
+ if match_type( self . cx, middle_ty, & paths:: ANCHOR_ACCOUNT ) ;
323
+ // grab T generic parameter
324
+ if let TyKind :: Adt ( _adt_def, substs) = middle_ty. kind( ) ;
325
+ if substs. len( ) == ANCHOR_ACCOUNT_GENERIC_ARG_COUNT ;
326
+ let account_type = substs[ 1 ] . expect_ty( ) ; // TODO: could just store middle::Ty instead of DefId?
327
+ if let Some ( adt_def) = account_type. ty_adt_def( ) ;
328
+ then {
329
+ let def_id = adt_def. did( ) ;
330
+ if let Some ( exprs) = self . uses. get_mut( & def_id) {
331
+ let mut spanless_eq = SpanlessEq :: new( self . cx) ;
332
+ // check that expr is not a duplicate within its particular key-pair
333
+ if exprs. iter( ) . all( |e| !spanless_eq. eq_expr( e, mut_expr) ) {
334
+ exprs. push( mut_expr) ;
335
+ }
336
+ } else {
337
+ self . uses. insert( def_id, vec![ mut_expr] ) ;
338
+ }
339
+ }
340
+ }
341
+ walk_expr ( self , expr) ;
342
+ }
343
+ }
344
+
345
+ /// Performs a walk on `body`, checking whether there exists an expression that contains
346
+ /// a `key()` method call on `account_expr`.
347
+ fn contains_key_call < ' tcx > (
348
+ cx : & LateContext < ' tcx > ,
349
+ body : & ' tcx Body < ' tcx > ,
350
+ account_expr : & Expr < ' tcx > ,
351
+ ) -> bool {
352
+ visit_expr_no_bodies ( & body. value , |expr| {
353
+ if_chain ! {
354
+ if let ExprKind :: MethodCall ( path_seg, exprs, _span) = expr. kind;
355
+ if path_seg. ident. name. as_str( ) == "key" ;
356
+ if !exprs. is_empty( ) ;
357
+ let mut spanless_eq = SpanlessEq :: new( cx) ;
358
+ if spanless_eq. eq_expr( & exprs[ 0 ] , account_expr) ;
359
+ then {
360
+ true
361
+ } else {
362
+ false
363
+ }
364
+ }
365
+ } )
366
+ }
367
+ }
368
+
369
+ // /// Splits `stream` into a vector of substreams, separated by `delimiter`.
370
+ // fn split(stream: CursorRef, delimiter: TokenKind) -> Vec<TokenStream> {
371
+ // let mut split_streams: Vec<TokenStream> = Vec::new();
372
+ // let mut temp: Vec<TreeAndSpacing> = Vec::new();
373
+ // let delim = TokenTree::Token(Token::new(delimiter, DUMMY_SP));
374
+
375
+ // stream.for_each(|t| {
376
+ // if t.eq_unspanned(&delim) {
377
+ // split_streams.push(TokenStream::new(temp.clone()));
378
+ // temp.clear();
379
+ // } else {
380
+ // temp.push(TreeAndSpacing::from(t.clone()));
381
+ // }
382
+ // });
383
+ // split_streams.push(TokenStream::new(temp));
384
+ // split_streams
385
+ // }
386
+
225
387
#[ test]
226
388
fn insecure ( ) {
227
389
dylint_testing:: ui_test_example ( env ! ( "CARGO_PKG_NAME" ) , "insecure" ) ;
0 commit comments