1111
1212import  org .elasticsearch .client .internal .node .NodeClient ;
1313import  org .elasticsearch .common .Strings ;
14+ import  org .elasticsearch .core .CheckedConsumer ;
1415import  org .elasticsearch .core .CheckedRunnable ;
1516import  org .elasticsearch .core .SuppressForbidden ;
1617import  org .elasticsearch .entitlement .runtime .api .NotEntitledException ;
18+ import  org .elasticsearch .env .Environment ;
1719import  org .elasticsearch .logging .LogManager ;
1820import  org .elasticsearch .logging .Logger ;
1921import  org .elasticsearch .rest .BaseRestHandler ;
@@ -70,7 +72,7 @@ public class RestEntitlementsCheckAction extends BaseRestHandler {
7072    private  static  final  Logger  logger  = LogManager .getLogger (RestEntitlementsCheckAction .class );
7173
7274    record  CheckAction (
73-         CheckedRunnable < Exception > action ,
75+         CheckedConsumer < Environment ,  Exception > action ,
7476        EntitlementTest .ExpectedAccess  expectedAccess ,
7577        Class <? extends  Exception > expectedExceptionIfDenied ,
7678        Integer  fromJavaVersion 
@@ -79,15 +81,15 @@ record CheckAction(
7981         * These cannot be granted to plugins, so our test plugins cannot test the "allowed" case. 
8082         */ 
8183        static  CheckAction  deniedToPlugins (CheckedRunnable <Exception > action ) {
82-             return  new  CheckAction (action , SERVER_ONLY , NotEntitledException .class , null );
84+             return  new  CheckAction (env  ->  action . run () , SERVER_ONLY , NotEntitledException .class , null );
8385        }
8486
8587        static  CheckAction  forPlugins (CheckedRunnable <Exception > action ) {
86-             return  new  CheckAction (action , PLUGINS , NotEntitledException .class , null );
88+             return  new  CheckAction (env  ->  action . run () , PLUGINS , NotEntitledException .class , null );
8789        }
8890
8991        static  CheckAction  alwaysDenied (CheckedRunnable <Exception > action ) {
90-             return  new  CheckAction (action , ALWAYS_DENIED , NotEntitledException .class , null );
92+             return  new  CheckAction (env  ->  action . run () , ALWAYS_DENIED , NotEntitledException .class , null );
9193        }
9294    }
9395
@@ -135,7 +137,7 @@ static CheckAction alwaysDenied(CheckedRunnable<Exception> action) {
135137            entry (
136138                "createInetAddressResolverProvider" ,
137139                new  CheckAction (
138-                     VersionSpecificNetworkChecks :: createInetAddressResolverProvider ,
140+                     env  ->  VersionSpecificNetworkChecks . createInetAddressResolverProvider () ,
139141                    SERVER_ONLY ,
140142                    NotEntitledException .class ,
141143                    18 
@@ -215,6 +217,12 @@ static CheckAction alwaysDenied(CheckedRunnable<Exception> action) {
215217        .filter (entry  -> entry .getValue ().fromJavaVersion () == null  || Runtime .version ().feature () >= entry .getValue ().fromJavaVersion ())
216218        .collect (Collectors .toUnmodifiableMap (Entry ::getKey , Entry ::getValue ));
217219
220+     private  final  Environment  environment ;
221+ 
222+     public  RestEntitlementsCheckAction (Environment  environment ) {
223+         this .environment  = environment ;
224+     }
225+ 
218226    @ SuppressForbidden (reason  = "Need package private methods so we don't have to make them all public" )
219227    private  static  Method [] getDeclaredMethods (Class <?> clazz ) {
220228        return  clazz .getDeclaredMethods ();
@@ -230,13 +238,10 @@ private static Stream<Entry<String, CheckAction>> getTestEntries(Class<?> action
230238            if  (Modifier .isStatic (method .getModifiers ()) == false ) {
231239                throw  new  AssertionError ("Entitlement test method ["  + method  + "] must be static" );
232240            }
233-             if  (method .getParameterTypes ().length  != 0 ) {
234-                 throw  new  AssertionError ("Entitlement test method ["  + method  + "] must not have parameters" );
235-             }
236- 
237-             CheckedRunnable <Exception > runnable  = () -> {
241+             final  CheckedConsumer <Environment , Exception > call  = createConsumerForMethod (method );
242+             CheckedConsumer <Environment , Exception > runnable  = env  -> {
238243                try  {
239-                     method . invoke ( null );
244+                     call . accept ( env );
240245                } catch  (IllegalAccessException  e ) {
241246                    throw  new  AssertionError (e );
242247                } catch  (InvocationTargetException  e ) {
@@ -258,6 +263,17 @@ private static Stream<Entry<String, CheckAction>> getTestEntries(Class<?> action
258263        return  entries .stream ();
259264    }
260265
266+     private  static  CheckedConsumer <Environment , Exception > createConsumerForMethod (Method  method ) {
267+         Class <?>[] parameters  = method .getParameterTypes ();
268+         if  (parameters .length  == 0 ) {
269+             return  env  -> method .invoke (null );
270+         }
271+         if  (parameters .length  == 1  && parameters [0 ].equals (Environment .class )) {
272+             return  env  -> method .invoke (null , env );
273+         }
274+         throw  new  AssertionError ("Entitlement test method ["  + method  + "] must have no parameters or 1 parameter (Environment)" );
275+     }
276+ 
261277    private  static  void  createURLStreamHandlerProvider () {
262278        var  x  = new  URLStreamHandlerProvider () {
263279            @ Override 
@@ -421,6 +437,14 @@ public static Set<String> getCheckActionsAllowedInPlugins() {
421437            .collect (Collectors .toSet ());
422438    }
423439
440+     public  static  Set <String > getAlwaysAllowedCheckActions () {
441+         return  checkActions .entrySet ()
442+             .stream ()
443+             .filter (kv  -> kv .getValue ().expectedAccess ().equals (ALWAYS_ALLOWED ))
444+             .map (Entry ::getKey )
445+             .collect (Collectors .toSet ());
446+     }
447+ 
424448    public  static  Set <String > getDeniableCheckActions () {
425449        return  checkActions .entrySet ()
426450            .stream ()
@@ -455,7 +479,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
455479            logger .info ("Calling check action [{}]" , actionName );
456480            RestResponse  response ;
457481            try  {
458-                 checkAction .action ().run ( );
482+                 checkAction .action ().accept ( environment );
459483                response  = new  RestResponse (RestStatus .OK , Strings .format ("Succesfully executed action [%s]" , actionName ));
460484            } catch  (Exception  e ) {
461485                var  statusCode  = checkAction .expectedExceptionIfDenied .isInstance (e )
@@ -468,5 +492,4 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
468492            channel .sendResponse (response );
469493        };
470494    }
471- 
472495}
0 commit comments