@@ -204,6 +204,8 @@ struct Builder {
204204
205205 // parameters used in the datalog source
206206 pub datalog_parameters : HashSet < String > ,
207+ // scope parameters used in the datalog source
208+ pub datalog_scope_parameters : HashSet < String > ,
207209 // parameters provided to the macro
208210 pub macro_parameters : HashSet < String > ,
209211
@@ -227,6 +229,7 @@ impl Builder {
227229 parameters,
228230
229231 datalog_parameters : HashSet :: new ( ) ,
232+ datalog_scope_parameters : HashSet :: new ( ) ,
230233 macro_parameters,
231234
232235 facts : Vec :: new ( ) ,
@@ -286,7 +289,8 @@ impl Builder {
286289 }
287290
288291 if let Some ( parameters) = & rule. scope_parameters {
289- self . datalog_parameters . extend ( parameters. keys ( ) . cloned ( ) ) ;
292+ self . datalog_scope_parameters
293+ . extend ( parameters. keys ( ) . cloned ( ) ) ;
290294 }
291295 }
292296
@@ -316,12 +320,17 @@ impl Builder {
316320 }
317321
318322 fn validate ( & self ) -> Result < ( ) , error:: LanguageError > {
319- if self . macro_parameters . is_subset ( & self . datalog_parameters ) {
323+ let all_parameters = self
324+ . datalog_parameters
325+ . union ( & self . datalog_scope_parameters )
326+ . cloned ( )
327+ . collect ( ) ;
328+ if self . macro_parameters . is_subset ( & all_parameters) {
320329 Ok ( ( ) )
321330 } else {
322331 let unused_parameters: Vec < String > = self
323332 . macro_parameters
324- . difference ( & self . datalog_parameters )
333+ . difference ( & all_parameters )
325334 . cloned ( )
326335 . collect ( ) ;
327336 Err ( error:: LanguageError :: Parameters {
@@ -334,6 +343,7 @@ impl Builder {
334343
335344struct Item {
336345 parameters : HashSet < String > ,
346+ scope_parameters : HashSet < String > ,
337347 start : TokenStream ,
338348 middle : TokenStream ,
339349 end : TokenStream ,
@@ -348,6 +358,7 @@ impl Item {
348358 . flatten ( )
349359 . map ( |( name, _) | name. to_owned ( ) )
350360 . collect ( ) ,
361+ scope_parameters : HashSet :: new ( ) ,
351362 start : quote ! {
352363 let mut __biscuit_auth_item = #fact;
353364 } ,
@@ -360,6 +371,7 @@ impl Item {
360371 fn rule ( rule : & Rule ) -> Self {
361372 Self {
362373 parameters : Item :: rule_params ( rule) . collect ( ) ,
374+ scope_parameters : Item :: rule_scope_params ( rule) . collect ( ) ,
363375 start : quote ! {
364376 let mut __biscuit_auth_item = #rule;
365377 } ,
@@ -373,6 +385,11 @@ impl Item {
373385 fn check ( check : & Check ) -> Self {
374386 Self {
375387 parameters : check. queries . iter ( ) . flat_map ( Item :: rule_params) . collect ( ) ,
388+ scope_parameters : check
389+ . queries
390+ . iter ( )
391+ . flat_map ( Item :: rule_scope_params)
392+ . collect ( ) ,
376393 start : quote ! {
377394 let mut __biscuit_auth_item = #check;
378395 } ,
@@ -386,6 +403,11 @@ impl Item {
386403 fn policy ( policy : & Policy ) -> Self {
387404 Self {
388405 parameters : policy. queries . iter ( ) . flat_map ( Item :: rule_params) . collect ( ) ,
406+ scope_parameters : policy
407+ . queries
408+ . iter ( )
409+ . flat_map ( Item :: rule_scope_params)
410+ . collect ( ) ,
389411 start : quote ! {
390412 let mut __biscuit_auth_item = #policy;
391413 } ,
@@ -400,19 +422,22 @@ impl Item {
400422 rule. parameters
401423 . iter ( )
402424 . flatten ( )
403- . map ( |( name, _) | name. as_ref ( ) )
404- . chain (
405- rule . scope_parameters
406- . iter ( )
407- . flatten ( )
408- . map ( | ( name , _ ) | name . as_ref ( ) ) ,
409- )
410- . map ( str :: to_owned )
425+ . map ( |( name, _) | name. to_string ( ) )
426+ }
427+
428+ fn rule_scope_params ( rule : & Rule ) -> impl Iterator < Item = String > + ' _ {
429+ rule . scope_parameters
430+ . iter ( )
431+ . flatten ( )
432+ . map ( | ( name , _ ) | name . to_string ( ) )
411433 }
412434
413435 fn needs_param ( & self , name : & str ) -> bool {
414436 self . parameters . contains ( name)
415437 }
438+ fn needs_scope_param ( & self , name : & str ) -> bool {
439+ self . scope_parameters . contains ( name)
440+ }
416441
417442 fn add_param ( & mut self , name : & str , clone : bool ) {
418443 let ident = Ident :: new ( name, Span :: call_site ( ) ) ;
@@ -427,6 +452,20 @@ impl Item {
427452 __biscuit_auth_item. set_macro_param( #name, #expr) . unwrap( ) ;
428453 } ) ;
429454 }
455+
456+ fn add_scope_param ( & mut self , name : & str , clone : bool ) {
457+ let ident = Ident :: new ( name, Span :: call_site ( ) ) ;
458+
459+ let expr = if clone {
460+ quote ! { :: core:: clone:: Clone :: clone( & #ident) }
461+ } else {
462+ quote ! { #ident }
463+ } ;
464+
465+ self . middle . extend ( quote ! {
466+ __biscuit_auth_item. set_macro_scope_param( #name, #expr) . unwrap( ) ;
467+ } ) ;
468+ }
430469}
431470
432471impl ToTokens for Item {
@@ -477,6 +516,21 @@ impl ToTokens for Builder {
477516 }
478517 }
479518
519+ for param in & self . datalog_scope_parameters {
520+ let mut items = items
521+ . iter_mut ( )
522+ . filter ( |i| i. needs_scope_param ( param) )
523+ . peekable ( ) ;
524+
525+ loop {
526+ match ( items. next ( ) , items. peek ( ) ) {
527+ ( Some ( cur) , Some ( _next) ) => cur. add_scope_param ( param, true ) ,
528+ ( Some ( cur) , None ) => cur. add_scope_param ( param, false ) ,
529+ ( None , _) => break ,
530+ }
531+ }
532+ }
533+
480534 let builder_type = & self . builder_type ;
481535 let builder_quote = if let Some ( target) = & self . target {
482536 quote ! {
@@ -558,6 +612,12 @@ pub fn rule(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
558612 }
559613 }
560614
615+ for param in & builder. datalog_scope_parameters {
616+ if rule_item. needs_scope_param ( param) {
617+ rule_item. add_scope_param ( param, false ) ;
618+ }
619+ }
620+
561621 ( quote ! {
562622 {
563623 #params_quote
@@ -694,6 +754,12 @@ pub fn check(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
694754 }
695755 }
696756
757+ for param in & builder. datalog_scope_parameters {
758+ if check_item. needs_scope_param ( param) {
759+ check_item. add_scope_param ( param, false ) ;
760+ }
761+ }
762+
697763 ( quote ! {
698764 {
699765 #params_quote
@@ -766,6 +832,12 @@ pub fn policy(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
766832 }
767833 }
768834
835+ for param in & builder. datalog_scope_parameters {
836+ if policy_item. needs_scope_param ( param) {
837+ policy_item. add_scope_param ( param, false ) ;
838+ }
839+ }
840+
769841 ( quote ! {
770842 {
771843 #params_quote
0 commit comments