1111
1212import org .elasticsearch .client .internal .node .NodeClient ;
1313import org .elasticsearch .common .Strings ;
14+ import org .elasticsearch .core .CheckedConsumer ;
1415import org .elasticsearch .core .CheckedRunnable ;
1516import org .elasticsearch .core .SuppressForbidden ;
1617import org .elasticsearch .entitlement .runtime .api .NotEntitledException ;
18+ import org .elasticsearch .env .Environment ;
1719import org .elasticsearch .logging .LogManager ;
1820import org .elasticsearch .logging .Logger ;
1921import org .elasticsearch .rest .BaseRestHandler ;
@@ -70,7 +72,7 @@ public class RestEntitlementsCheckAction extends BaseRestHandler {
7072 private static final Logger logger = LogManager .getLogger (RestEntitlementsCheckAction .class );
7173
7274 record CheckAction (
73- CheckedRunnable < Exception > action ,
75+ CheckedConsumer < Environment , Exception > action ,
7476 EntitlementTest .ExpectedAccess expectedAccess ,
7577 Class <? extends Exception > expectedExceptionIfDenied ,
7678 Integer fromJavaVersion
@@ -79,15 +81,15 @@ record CheckAction(
7981 * These cannot be granted to plugins, so our test plugins cannot test the "allowed" case.
8082 */
8183 static CheckAction deniedToPlugins (CheckedRunnable <Exception > action ) {
82- return new CheckAction (action , SERVER_ONLY , NotEntitledException .class , null );
84+ return new CheckAction (env -> action . run () , SERVER_ONLY , NotEntitledException .class , null );
8385 }
8486
8587 static CheckAction forPlugins (CheckedRunnable <Exception > action ) {
86- return new CheckAction (action , PLUGINS , NotEntitledException .class , null );
88+ return new CheckAction (env -> action . run () , PLUGINS , NotEntitledException .class , null );
8789 }
8890
8991 static CheckAction alwaysDenied (CheckedRunnable <Exception > action ) {
90- return new CheckAction (action , ALWAYS_DENIED , NotEntitledException .class , null );
92+ return new CheckAction (env -> action . run () , ALWAYS_DENIED , NotEntitledException .class , null );
9193 }
9294 }
9395
@@ -135,7 +137,7 @@ static CheckAction alwaysDenied(CheckedRunnable<Exception> action) {
135137 entry (
136138 "createInetAddressResolverProvider" ,
137139 new CheckAction (
138- VersionSpecificNetworkChecks :: createInetAddressResolverProvider ,
140+ env -> VersionSpecificNetworkChecks . createInetAddressResolverProvider () ,
139141 SERVER_ONLY ,
140142 NotEntitledException .class ,
141143 18
@@ -217,6 +219,12 @@ static CheckAction alwaysDenied(CheckedRunnable<Exception> action) {
217219 .filter (entry -> entry .getValue ().fromJavaVersion () == null || Runtime .version ().feature () >= entry .getValue ().fromJavaVersion ())
218220 .collect (Collectors .toUnmodifiableMap (Entry ::getKey , Entry ::getValue ));
219221
222+ private final Environment environment ;
223+
224+ public RestEntitlementsCheckAction (Environment environment ) {
225+ this .environment = environment ;
226+ }
227+
220228 @ SuppressForbidden (reason = "Need package private methods so we don't have to make them all public" )
221229 private static Method [] getDeclaredMethods (Class <?> clazz ) {
222230 return clazz .getDeclaredMethods ();
@@ -232,13 +240,10 @@ private static Stream<Entry<String, CheckAction>> getTestEntries(Class<?> action
232240 if (Modifier .isStatic (method .getModifiers ()) == false ) {
233241 throw new AssertionError ("Entitlement test method [" + method + "] must be static" );
234242 }
235- if (method .getParameterTypes ().length != 0 ) {
236- throw new AssertionError ("Entitlement test method [" + method + "] must not have parameters" );
237- }
238-
239- CheckedRunnable <Exception > runnable = () -> {
243+ final CheckedConsumer <Environment , Exception > call = createConsumerForMethod (method );
244+ CheckedConsumer <Environment , Exception > runnable = env -> {
240245 try {
241- method . invoke ( null );
246+ call . accept ( env );
242247 } catch (IllegalAccessException e ) {
243248 throw new AssertionError (e );
244249 } catch (InvocationTargetException e ) {
@@ -260,6 +265,17 @@ private static Stream<Entry<String, CheckAction>> getTestEntries(Class<?> action
260265 return entries .stream ();
261266 }
262267
268+ private static CheckedConsumer <Environment , Exception > createConsumerForMethod (Method method ) {
269+ Class <?>[] parameters = method .getParameterTypes ();
270+ if (parameters .length == 0 ) {
271+ return env -> method .invoke (null );
272+ }
273+ if (parameters .length == 1 && parameters [0 ].equals (Environment .class )) {
274+ return env -> method .invoke (null , env );
275+ }
276+ throw new AssertionError ("Entitlement test method [" + method + "] must have no parameters or 1 parameter (Environment)" );
277+ }
278+
263279 private static void createURLStreamHandlerProvider () {
264280 var x = new URLStreamHandlerProvider () {
265281 @ Override
@@ -423,6 +439,14 @@ public static Set<String> getCheckActionsAllowedInPlugins() {
423439 .collect (Collectors .toSet ());
424440 }
425441
442+ public static Set <String > getAlwaysAllowedCheckActions () {
443+ return checkActions .entrySet ()
444+ .stream ()
445+ .filter (kv -> kv .getValue ().expectedAccess ().equals (ALWAYS_ALLOWED ))
446+ .map (Entry ::getKey )
447+ .collect (Collectors .toSet ());
448+ }
449+
426450 public static Set <String > getDeniableCheckActions () {
427451 return checkActions .entrySet ()
428452 .stream ()
@@ -457,7 +481,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
457481 logger .info ("Calling check action [{}]" , actionName );
458482 RestResponse response ;
459483 try {
460- checkAction .action ().run ( );
484+ checkAction .action ().accept ( environment );
461485 response = new RestResponse (RestStatus .OK , Strings .format ("Succesfully executed action [%s]" , actionName ));
462486 } catch (Exception e ) {
463487 var statusCode = checkAction .expectedExceptionIfDenied .isInstance (e )
@@ -470,5 +494,4 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
470494 channel .sendResponse (response );
471495 };
472496 }
473-
474497}
0 commit comments