Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,35 @@
public class PolicyManager {
private static final Logger logger = LogManager.getLogger(PolicyManager.class);

record ModuleEntitlements(Map<Class<? extends Entitlement>, List<Entitlement>> entitlementsByType, FileAccessTree fileAccess) {
public static final ModuleEntitlements NONE = new ModuleEntitlements(Map.of(), FileAccessTree.EMPTY);
public static final String UNKNOWN_COMPONENT_NAME = "(unknown)";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These can all be at most package private?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh true. I think they're public from a prior revision of this code.

public static final String SERVER_COMPONENT_NAME = "(server)";
public static final String AGENT_COMPONENT_NAME = "(agent)";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually apm agent. Just "agent" may be confused with the entitlement agent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point.


/**
* @param componentName the plugin name; or else one of the special component names
* like {@link #SERVER_COMPONENT_NAME} or {@link #AGENT_COMPONENT_NAME}.
*/
record ModuleEntitlements(
String componentName,
Map<Class<? extends Entitlement>, List<Entitlement>> entitlementsByType,
FileAccessTree fileAccess
) {

ModuleEntitlements {
entitlementsByType = Map.copyOf(entitlementsByType);
}

public static ModuleEntitlements from(List<Entitlement> entitlements) {
public static ModuleEntitlements none(String componentName) {
return new ModuleEntitlements(componentName, Map.of(), FileAccessTree.EMPTY);
}

public static ModuleEntitlements from(String componentName, List<Entitlement> entitlements) {
var fileEntitlements = entitlements.stream()
.filter(e -> e.getClass().equals(FileEntitlement.class))
.map(e -> (FileEntitlement) e)
.toList();
return new ModuleEntitlements(
componentName,
entitlements.stream().collect(groupingBy(Entitlement::getClass)),
FileAccessTree.of(fileEntitlements)
);
Expand Down Expand Up @@ -184,9 +200,10 @@ private void neverEntitled(Class<?> callerClass, Supplier<String> operationDescr

throw new NotEntitledException(
Strings.format(
"Not entitled: caller [%s], module [%s], operation [%s]",
callerClass,
requestingClass.getModule() == null ? "<none>" : requestingClass.getModule().getName(),
"Not entitled: component [%s], module [%s], class [%s], operation [%s]",
getEntitlements(requestingClass).componentName(),
requestingClass.getModule().getName(),
requestingClass,
operationDescription.get()
)
);
Expand Down Expand Up @@ -240,9 +257,10 @@ public void checkFileRead(Class<?> callerClass, Path path) {
if (entitlements.fileAccess().canRead(path) == false) {
throw new NotEntitledException(
Strings.format(
"Not entitled: caller [%s], module [%s], entitlement [file], operation [read], path [%s]",
callerClass,
"Not entitled: component [%s], module [%s], class [%s], entitlement [file], operation [read], path [%s]",
entitlements.componentName(),
requestingClass.getModule(),
requestingClass,
path
)
);
Expand All @@ -264,9 +282,10 @@ public void checkFileWrite(Class<?> callerClass, Path path) {
if (entitlements.fileAccess().canWrite(path) == false) {
throw new NotEntitledException(
Strings.format(
"Not entitled: caller [%s], module [%s], entitlement [file], operation [write], path [%s]",
callerClass,
"Not entitled: component [%s], module [%s], class [%s], entitlement [file], operation [write], path [%s]",
entitlements.componentName(),
requestingClass.getModule(),
requestingClass,
path
)
);
Expand Down Expand Up @@ -300,30 +319,33 @@ public void checkAllNetworkAccess(Class<?> callerClass) {
}

var classEntitlements = getEntitlements(requestingClass);
if (classEntitlements.hasEntitlement(InboundNetworkEntitlement.class) == false) {
throw new NotEntitledException(
Strings.format(
"Missing entitlement: class [%s], module [%s], entitlement [inbound_network]",
requestingClass,
requestingClass.getModule().getName()
)
);
}
checkFlagEntitlement(classEntitlements, InboundNetworkEntitlement.class, requestingClass);
checkFlagEntitlement(classEntitlements, OutboundNetworkEntitlement.class, requestingClass);
}

if (classEntitlements.hasEntitlement(OutboundNetworkEntitlement.class) == false) {
private static void checkFlagEntitlement(
ModuleEntitlements classEntitlements,
Class<? extends Entitlement> entitlementClass,
Class<?> requestingClass
) {
if (classEntitlements.hasEntitlement(entitlementClass) == false) {
throw new NotEntitledException(
Strings.format(
"Missing entitlement: class [%s], module [%s], entitlement [outbound_network]",
"Not entitled: component [%s], module [%s], class [%s], entitlement [%s]",
classEntitlements.componentName(),
requestingClass.getModule().getName(),
requestingClass,
requestingClass.getModule().getName()
PolicyParser.getEntitlementTypeName(entitlementClass)
)
);
}
logger.debug(
() -> Strings.format(
"Entitled: class [%s], module [%s], entitlements [inbound_network, outbound_network]",
"Entitled: component [%s], module [%s], class [%s], entitlement [%s]",
classEntitlements.componentName(),
requestingClass.getModule().getName(),
requestingClass,
requestingClass.getModule().getName()
PolicyParser.getEntitlementTypeName(entitlementClass)
)
);
}
Expand All @@ -338,19 +360,21 @@ public void checkWriteProperty(Class<?> callerClass, String property) {
if (entitlements.getEntitlements(WriteSystemPropertiesEntitlement.class).anyMatch(e -> e.properties().contains(property))) {
logger.debug(
() -> Strings.format(
"Entitled: class [%s], module [%s], entitlement [write_system_properties], property [%s]",
requestingClass,
"Entitled: component [%s], module [%s], class [%s], entitlement [write_system_properties], property [%s]",
entitlements.componentName(),
requestingClass.getModule().getName(),
requestingClass,
property
)
);
return;
}
throw new NotEntitledException(
Strings.format(
"Missing entitlement: class [%s], module [%s], entitlement [write_system_properties], property [%s]",
requestingClass,
"Not entitled: component [%s], module [%s], class [%s], entitlement [write_system_properties], property [%s]",
entitlements.componentName(),
requestingClass.getModule().getName(),
requestingClass,
property
)
);
Expand All @@ -361,27 +385,7 @@ private void checkEntitlementPresent(Class<?> callerClass, Class<? extends Entit
if (isTriviallyAllowed(requestingClass)) {
return;
}

ModuleEntitlements entitlements = getEntitlements(requestingClass);
if (entitlements.hasEntitlement(entitlementClass)) {
logger.debug(
() -> Strings.format(
"Entitled: class [%s], module [%s], entitlement [%s]",
requestingClass,
requestingClass.getModule().getName(),
PolicyParser.getEntitlementTypeName(entitlementClass)
)
);
return;
}
throw new NotEntitledException(
Strings.format(
"Missing entitlement: class [%s], module [%s], entitlement [%s]",
requestingClass,
requestingClass.getModule().getName(),
PolicyParser.getEntitlementTypeName(entitlementClass)
)
);
checkFlagEntitlement(getEntitlements(requestingClass), entitlementClass, requestingClass);
}

ModuleEntitlements getEntitlements(Class<?> requestingClass) {
Expand All @@ -391,47 +395,44 @@ ModuleEntitlements getEntitlements(Class<?> requestingClass) {
private ModuleEntitlements computeEntitlements(Class<?> requestingClass) {
Module requestingModule = requestingClass.getModule();
if (isServerModule(requestingModule)) {
return getModuleScopeEntitlements(requestingClass, serverEntitlements, requestingModule.getName(), "server");
return getModuleScopeEntitlements(serverEntitlements, requestingModule.getName(), SERVER_COMPONENT_NAME);
}

// plugins
var pluginName = pluginResolver.apply(requestingClass);
if (pluginName != null) {
var pluginEntitlements = pluginsEntitlements.get(pluginName);
if (pluginEntitlements == null) {
return ModuleEntitlements.NONE;
return ModuleEntitlements.none(pluginName);
} else {
final String scopeName;
if (requestingModule.isNamed() == false) {
scopeName = ALL_UNNAMED;
} else {
scopeName = requestingModule.getName();
}
return getModuleScopeEntitlements(requestingClass, pluginEntitlements, scopeName, pluginName);
return getModuleScopeEntitlements(pluginEntitlements, scopeName, pluginName);
}
}

if (requestingModule.isNamed() == false && requestingClass.getPackageName().startsWith(agentsPackageName)) {
// agents are the only thing running non-modular in the system classloader
return ModuleEntitlements.from(agentEntitlements);
return ModuleEntitlements.from(AGENT_COMPONENT_NAME, agentEntitlements);
}

logger.warn("No applicable entitlement policy for class [{}]", requestingClass.getName());
return ModuleEntitlements.NONE;
return ModuleEntitlements.none(UNKNOWN_COMPONENT_NAME);
}

private ModuleEntitlements getModuleScopeEntitlements(
Class<?> callerClass,
Map<String, List<Entitlement>> scopeEntitlements,
String moduleName,
String component
String componentName
) {
var entitlements = scopeEntitlements.get(moduleName);
if (entitlements == null) {
logger.warn("No applicable entitlement policy for [{}], module [{}], class [{}]", component, moduleName, callerClass);
return ModuleEntitlements.NONE;
return ModuleEntitlements.none(componentName);
}
return ModuleEntitlements.from(entitlements);
return ModuleEntitlements.from(componentName, entitlements);
}

private static boolean isServerModule(Module requestingModule) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import static java.util.Map.entry;
import static org.elasticsearch.entitlement.runtime.policy.PolicyManager.ALL_UNNAMED;
import static org.elasticsearch.entitlement.runtime.policy.PolicyManager.SERVER_COMPONENT_NAME;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
Expand Down Expand Up @@ -77,9 +78,9 @@ public void testGetEntitlementsThrowsOnMissingPluginUnnamedModule() {
var callerClass = this.getClass();
var requestingModule = callerClass.getModule();

assertEquals("No policy for the unnamed module", ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass));
assertEquals("No policy for the unnamed module", ModuleEntitlements.none("plugin1"), policyManager.getEntitlements(callerClass));

assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
assertEquals(Map.of(requestingModule, ModuleEntitlements.none("plugin1")), policyManager.moduleEntitlementsMap);
}

public void testGetEntitlementsThrowsOnMissingPolicyForPlugin() {
Expand All @@ -96,9 +97,9 @@ public void testGetEntitlementsThrowsOnMissingPolicyForPlugin() {
var callerClass = this.getClass();
var requestingModule = callerClass.getModule();

assertEquals("No policy for this plugin", ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass));
assertEquals("No policy for this plugin", ModuleEntitlements.none("plugin1"), policyManager.getEntitlements(callerClass));

assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
assertEquals(Map.of(requestingModule, ModuleEntitlements.none("plugin1")), policyManager.moduleEntitlementsMap);
}

public void testGetEntitlementsFailureIsCached() {
Expand All @@ -115,14 +116,14 @@ public void testGetEntitlementsFailureIsCached() {
var callerClass = this.getClass();
var requestingModule = callerClass.getModule();

assertEquals(ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass));
assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
assertEquals(ModuleEntitlements.none("plugin1"), policyManager.getEntitlements(callerClass));
assertEquals(Map.of(requestingModule, ModuleEntitlements.none("plugin1")), policyManager.moduleEntitlementsMap);

// A second time
assertEquals(ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass));
assertEquals(ModuleEntitlements.none("plugin1"), policyManager.getEntitlements(callerClass));

// Nothing new in the map
assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
assertEquals(Map.of(requestingModule, ModuleEntitlements.none("plugin1")), policyManager.moduleEntitlementsMap);
}

public void testGetEntitlementsReturnsEntitlementsForPluginUnnamedModule() {
Expand Down Expand Up @@ -159,9 +160,13 @@ public void testGetEntitlementsThrowsOnMissingPolicyForServer() throws ClassNotF
var mockServerClass = ModuleLayer.boot().findLoader("jdk.httpserver").loadClass("com.sun.net.httpserver.HttpServer");
var requestingModule = mockServerClass.getModule();

assertEquals("No policy for this module in server", ModuleEntitlements.NONE, policyManager.getEntitlements(mockServerClass));
assertEquals(
"No policy for this module in server",
ModuleEntitlements.none(SERVER_COMPONENT_NAME),
policyManager.getEntitlements(mockServerClass)
);

assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
assertEquals(Map.of(requestingModule, ModuleEntitlements.none(SERVER_COMPONENT_NAME)), policyManager.moduleEntitlementsMap);
}

public void testGetEntitlementsReturnsEntitlementsForServerModule() throws ClassNotFoundException {
Expand Down