1111
1212import org .elasticsearch .client .internal .node .NodeClient ;
1313import org .elasticsearch .common .Strings ;
14+ import org .elasticsearch .core .CheckedConsumer ;
1415import org .elasticsearch .core .CheckedRunnable ;
1516import org .elasticsearch .core .SuppressForbidden ;
17+ import org .elasticsearch .env .Environment ;
1618import org .elasticsearch .logging .LogManager ;
1719import org .elasticsearch .logging .Logger ;
1820import org .elasticsearch .rest .BaseRestHandler ;
6870public class RestEntitlementsCheckAction extends BaseRestHandler {
6971 private static final Logger logger = LogManager .getLogger (RestEntitlementsCheckAction .class );
7072
71- record CheckAction (CheckedRunnable <Exception > action , EntitlementTest .ExpectedAccess expectedAccess , Integer fromJavaVersion ) {
73+ record CheckAction (
74+ CheckedConsumer <Environment , Exception > action ,
75+ EntitlementTest .ExpectedAccess expectedAccess ,
76+ Integer fromJavaVersion
77+ ) {
7278 /**
7379 * These cannot be granted to plugins, so our test plugins cannot test the "allowed" case.
7480 */
7581 static CheckAction deniedToPlugins (CheckedRunnable <Exception > action ) {
76- return new CheckAction (action , SERVER_ONLY , null );
82+ return new CheckAction (env -> action . run () , SERVER_ONLY , null );
7783 }
7884
7985 static CheckAction forPlugins (CheckedRunnable <Exception > action ) {
80- return new CheckAction (action , PLUGINS , null );
86+ return new CheckAction (env -> action . run () , PLUGINS , null );
8187 }
8288
8389 static CheckAction alwaysDenied (CheckedRunnable <Exception > action ) {
84- return new CheckAction (action , ALWAYS_DENIED , null );
90+ return new CheckAction (env -> action . run () , ALWAYS_DENIED , null );
8591 }
8692 }
8793
@@ -128,7 +134,7 @@ static CheckAction alwaysDenied(CheckedRunnable<Exception> action) {
128134 entry ("responseCache_setDefault" , alwaysDenied (RestEntitlementsCheckAction ::setDefaultResponseCache )),
129135 entry (
130136 "createInetAddressResolverProvider" ,
131- new CheckAction (VersionSpecificNetworkChecks :: createInetAddressResolverProvider , SERVER_ONLY , 18 )
137+ new CheckAction (env -> VersionSpecificNetworkChecks . createInetAddressResolverProvider () , SERVER_ONLY , 18 )
132138 ),
133139 entry ("createURLStreamHandlerProvider" , alwaysDenied (RestEntitlementsCheckAction ::createURLStreamHandlerProvider )),
134140 entry ("createURLWithURLStreamHandler" , alwaysDenied (RestEntitlementsCheckAction ::createURLWithURLStreamHandler )),
@@ -204,6 +210,12 @@ static CheckAction alwaysDenied(CheckedRunnable<Exception> action) {
204210 .filter (entry -> entry .getValue ().fromJavaVersion () == null || Runtime .version ().feature () >= entry .getValue ().fromJavaVersion ())
205211 .collect (Collectors .toUnmodifiableMap (Entry ::getKey , Entry ::getValue ));
206212
213+ private final Environment environment ;
214+
215+ public RestEntitlementsCheckAction (Environment environment ) {
216+ this .environment = environment ;
217+ }
218+
207219 @ SuppressForbidden (reason = "Need package private methods so we don't have to make them all public" )
208220 private static Method [] getDeclaredMethods (Class <?> clazz ) {
209221 return clazz .getDeclaredMethods ();
@@ -219,13 +231,10 @@ private static Stream<Entry<String, CheckAction>> getTestEntries(Class<?> action
219231 if (Modifier .isStatic (method .getModifiers ()) == false ) {
220232 throw new AssertionError ("Entitlement test method [" + method + "] must be static" );
221233 }
222- if (method .getParameterTypes ().length != 0 ) {
223- throw new AssertionError ("Entitlement test method [" + method + "] must not have parameters" );
224- }
225-
226- CheckedRunnable <Exception > runnable = () -> {
234+ final CheckedConsumer <Environment , Exception > call = createConsumerForMethod (method );
235+ CheckedConsumer <Environment , Exception > runnable = env -> {
227236 try {
228- method . invoke ( null );
237+ call . accept ( env );
229238 } catch (IllegalAccessException e ) {
230239 throw new AssertionError (e );
231240 } catch (InvocationTargetException e ) {
@@ -242,6 +251,17 @@ private static Stream<Entry<String, CheckAction>> getTestEntries(Class<?> action
242251 return entries .stream ();
243252 }
244253
254+ private static CheckedConsumer <Environment , Exception > createConsumerForMethod (Method method ) {
255+ Class <?>[] parameters = method .getParameterTypes ();
256+ if (parameters .length == 0 ) {
257+ return env -> method .invoke (null );
258+ }
259+ if (parameters .length == 1 && parameters [0 ].equals (Environment .class )) {
260+ return env -> method .invoke (null , env );
261+ }
262+ throw new AssertionError ("Entitlement test method [" + method + "] must have no parameters or 1 parameter (Environment)" );
263+ }
264+
245265 private static void createURLStreamHandlerProvider () {
246266 var x = new URLStreamHandlerProvider () {
247267 @ Override
@@ -405,6 +425,14 @@ public static Set<String> getCheckActionsAllowedInPlugins() {
405425 .collect (Collectors .toSet ());
406426 }
407427
428+ public static Set <String > getAlwaysAllowedCheckActions () {
429+ return checkActions .entrySet ()
430+ .stream ()
431+ .filter (kv -> kv .getValue ().expectedAccess ().equals (ALWAYS_ALLOWED ))
432+ .map (Entry ::getKey )
433+ .collect (Collectors .toSet ());
434+ }
435+
408436 public static Set <String > getDeniableCheckActions () {
409437 return checkActions .entrySet ()
410438 .stream ()
@@ -437,10 +465,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
437465
438466 return channel -> {
439467 logger .info ("Calling check action [{}]" , actionName );
440- checkAction .action ().run ( );
468+ checkAction .action ().accept ( environment );
441469 logger .debug ("Check action [{}] returned" , actionName );
442470 channel .sendResponse (new RestResponse (RestStatus .OK , Strings .format ("Succesfully executed action [%s]" , actionName )));
443471 };
444472 }
445-
446473}
0 commit comments