Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,10 @@ public static boolean isSpatialPoint(DataType t) {
return t == GEO_POINT || t == CARTESIAN_POINT;
}

public static boolean isSpatialShape(DataType t) {
return t == GEO_SHAPE || t == CARTESIAN_SHAPE;
}

public static boolean isSpatialGeo(DataType t) {
return t == GEO_POINT || t == GEO_SHAPE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.test.TestBlockFactory;
import org.elasticsearch.indices.CrankyCircuitBreakerService;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.license.internal.XPackLicenseStatus;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.LicenseAware;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
Expand All @@ -51,7 +55,6 @@
import org.junit.After;
import org.junit.AfterClass;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -701,7 +704,8 @@ public void testSerializationOfSimple() {
*/
@AfterClass
public static void testFunctionInfo() {
Logger log = LogManager.getLogger(getTestClass());
Class<?> testClass = getTestClass();
Logger log = LogManager.getLogger(testClass);
FunctionDefinition definition = definition(functionName());
if (definition == null) {
log.info("Skipping function info checks because the function isn't registered");
Expand All @@ -724,7 +728,7 @@ public static void testFunctionInfo() {
for (int i = 0; i < args.size(); i++) {
typesFromSignature.add(new HashSet<>());
}
for (Map.Entry<List<DataType>, DataType> entry : signatures(getTestClass()).entrySet()) {
for (Map.Entry<List<DataType>, DataType> entry : signatures(testClass).entrySet()) {
List<DataType> types = entry.getKey();
for (int i = 0; i < args.size() && i < types.size(); i++) {
typesFromSignature.get(i).add(types.get(i).esNameIfPossible());
Expand Down Expand Up @@ -767,6 +771,101 @@ public static void testFunctionInfo() {
assertEquals(returnFromSignature, returnTypes);
}

/**
* This test is meant to validate that the license checks documented match those enforced.
* The expectations are set in the test class using a method with this signature:
* <code>
* public static License.OperationMode licenseRequirement(List&lt;DataType&gt; fieldTypes);
* </code>
* License enforcement in the function class is achieved using the interface <code>LicenseAware</code>.
* This test will make sure the two are in agreement, and does not require that the function class actually
* report its license level. If we add license checks to any function, but fail to also add the expected
* license level to the test class, this test will fail.
*/
@AfterClass
public static void testFunctionLicenseChecks() throws Exception {
Class<?> testClass = getTestClass();
Logger log = LogManager.getLogger(testClass);
FunctionDefinition definition = definition(functionName());
if (definition == null) {
log.info("Skipping function info checks because the function isn't registered");
return;
}
log.info("Running function license checks");
DocsV3Support.LicenseRequirementChecker licenseChecker = new DocsV3Support.LicenseRequirementChecker(testClass);
License.OperationMode functionLicense = licenseChecker.invoke(null);
Constructor<?> ctor = constructorWithFunctionInfo(definition.clazz());
if (LicenseAware.class.isAssignableFrom(definition.clazz()) == false) {
// Perform simpler no-signature tests
assertThat(
"Function " + definition.name() + " should be licensed under " + functionLicense,
functionLicense,
equalTo(License.OperationMode.BASIC)
);
return;
}
// For classes with LicenseAware, we need to check that the license is correct
TestCheckLicense checkLicense = new TestCheckLicense();

// Go through all signatures and assert that the license is as expected
signatures(testClass).forEach((signature, returnType) -> {
try {
License.OperationMode license = licenseChecker.invoke(signature);
assertNotNull("License should not be null", license);

// Construct an instance of the class and then call it's licenseCheck method, and compare the results
Object[] args = new Object[signature.size() + 1];
args[0] = Source.EMPTY;
for (int i = 0; i < signature.size(); i++) {
args[i + 1] = new Literal(Source.EMPTY, null, signature.get(i));
}
Object instance = ctor.newInstance(args);
// Check that object implements the LicenseAware interface
if (LicenseAware.class.isAssignableFrom(instance.getClass())) {
LicenseAware licenseAware = (LicenseAware) instance;
switch (license) {
case BASIC -> checkLicense.assertLicenseCheck(licenseAware, signature, true, true, true);
case PLATINUM -> checkLicense.assertLicenseCheck(licenseAware, signature, false, true, true);
case ENTERPRISE -> checkLicense.assertLicenseCheck(licenseAware, signature, false, false, true);
}
} else {
fail("Function " + definition.name() + " does not implement LicenseAware");
}
} catch (Exception e) {
fail(e);
}
});
}

private static class TestCheckLicense {
XPackLicenseState basicLicense = makeLicenseState(License.OperationMode.BASIC);
XPackLicenseState platinumLicense = makeLicenseState(License.OperationMode.PLATINUM);
XPackLicenseState enterpriseLicense = makeLicenseState(License.OperationMode.ENTERPRISE);

private void assertLicenseCheck(
LicenseAware licenseAware,
List<DataType> signature,
boolean allowsBasic,
boolean allowsPlatinum,
boolean allowsEnterprise
) {
boolean basic = licenseAware.licenseCheck(basicLicense);
boolean platinum = licenseAware.licenseCheck(platinumLicense);
boolean enterprise = licenseAware.licenseCheck(enterpriseLicense);
assertThat("Basic license should be accepted for " + signature, basic, equalTo(allowsBasic));
assertThat("Platinum license should be accepted for " + signature, platinum, equalTo(allowsPlatinum));
assertThat("Enterprise license should be accepted for " + signature, enterprise, equalTo(allowsEnterprise));
}

private void assertLicenseCheck(List<DataType> signature, boolean allowed, boolean expected) {
assertThat("Basic license should " + (expected ? "" : "not ") + "be accepted for " + signature, allowed, equalTo(expected));
}
}

private static XPackLicenseState makeLicenseState(License.OperationMode mode) {
return new XPackLicenseState(System::currentTimeMillis, new XPackLicenseStatus(mode, true, ""));
}

/**
* Asserts the result of a test case matches the expected result and warnings.
* <p>
Expand Down Expand Up @@ -836,7 +935,7 @@ public static Map<List<DataType>, DataType> signatures(Class<?> testClass) {
}

@AfterClass
public static void renderDocs() throws IOException {
public static void renderDocs() throws Exception {
if (System.getProperty("generateDocs") == null) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.elasticsearch.common.Strings;
import org.elasticsearch.core.PathUtils;
import org.elasticsearch.license.License;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -46,6 +47,7 @@
import java.io.InputStreamReader;
import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -107,7 +109,7 @@ static OperatorsDocsSupport forOperators(String name, Class<?> testClass) {
return new OperatorsDocsSupport(name, testClass);
}

static void renderDocs(String name, Class<?> testClass) throws IOException {
static void renderDocs(String name, Class<?> testClass) throws Exception {
if (OPERATORS.containsKey(name)) {
var docs = DocsV3Support.forOperators(name, testClass);
docs.renderSignature();
Expand All @@ -126,7 +128,7 @@ public static void renderNegatedOperator(
String name,
Function<String, String> description,
Class<?> testClass
) throws IOException {
) throws Exception {
var docs = forOperators("not " + name.toLowerCase(Locale.ROOT), testClass);
docs.renderDocsForNegatedOperators(ctor, description);
}
Expand Down Expand Up @@ -272,12 +274,46 @@ public void writeToTempDir(Path dir, String extension, String str) throws IOExce
}
}

/**
* This class is used to check if a license requirement method exists in the test class.
* This is used to add license requirement information to the generated documentation.
*/
public static class LicenseRequirementChecker {
private Method staticMethod;
private Function<List<DataType>, License.OperationMode> fallbackLambda;

public LicenseRequirementChecker(Class<?> testClass) {
try {
staticMethod = testClass.getMethod("licenseRequirement", List.class);
if (License.OperationMode.class.equals(staticMethod.getReturnType()) == false
|| java.lang.reflect.Modifier.isStatic(staticMethod.getModifiers()) == false) {
staticMethod = null; // Reset if the method doesn't match the signature
}
} catch (NoSuchMethodException e) {
staticMethod = null;
}

if (staticMethod == null) {
fallbackLambda = fieldTypes -> License.OperationMode.BASIC;
}
}

public License.OperationMode invoke(List<DataType> fieldTypes) throws Exception {
if (staticMethod != null) {
return (License.OperationMode) staticMethod.invoke(null, fieldTypes);
} else {
return fallbackLambda.apply(fieldTypes);
}
}
}

protected final String category;
protected final String name;
protected final FunctionDefinition definition;
protected final Logger logger;
private final Supplier<Map<List<DataType>, DataType>> signatures;
private TempFileWriter tempFileWriter;
private final LicenseRequirementChecker licenseChecker;

protected DocsV3Support(String category, String name, Class<?> testClass, Supplier<Map<List<DataType>, DataType>> signatures) {
this(category, name, null, testClass, signatures);
Expand All @@ -296,6 +332,7 @@ private DocsV3Support(
this.logger = LogManager.getLogger(testClass);
this.signatures = signatures;
this.tempFileWriter = new DocsFileWriter();
this.licenseChecker = new LicenseRequirementChecker(testClass);
}

/** Used in tests to capture output for asserting on the content */
Expand Down Expand Up @@ -460,7 +497,7 @@ void writeToTempKibanaDir(String subdir, String extension, String str) throws IO

protected abstract void renderSignature() throws IOException;

protected abstract void renderDocs() throws IOException;
protected abstract void renderDocs() throws Exception;

static class FunctionDocsSupport extends DocsV3Support {
private FunctionDocsSupport(String name, Class<?> testClass) {
Expand Down Expand Up @@ -488,7 +525,7 @@ protected void renderSignature() throws IOException {
}

@Override
protected void renderDocs() throws IOException {
protected void renderDocs() throws Exception {
if (definition == null) {
logger.info("Skipping rendering docs because the function '{}' isn't registered", name);
} else {
Expand All @@ -497,7 +534,7 @@ protected void renderDocs() throws IOException {
}
}

private void renderDocs(FunctionDefinition definition) throws IOException {
private void renderDocs(FunctionDefinition definition) throws Exception {
EsqlFunctionRegistry.FunctionDescription description = EsqlFunctionRegistry.description(definition);
if (name.equals("case")) {
/*
Expand Down Expand Up @@ -711,7 +748,7 @@ public void renderSignature() throws IOException {
}

@Override
public void renderDocs() throws IOException {
public void renderDocs() throws Exception {
Constructor<?> ctor = constructorWithFunctionInfo(op.clazz());
if (ctor != null) {
FunctionInfo functionInfo = ctor.getAnnotation(FunctionInfo.class);
Expand All @@ -722,7 +759,7 @@ public void renderDocs() throws IOException {
}
}

void renderDocsForNegatedOperators(Constructor<?> ctor, Function<String, String> description) throws IOException {
void renderDocsForNegatedOperators(Constructor<?> ctor, Function<String, String> description) throws Exception {
String baseName = name.toLowerCase(Locale.ROOT).replace("not ", "");
OperatorConfig op = OPERATORS.get(baseName);
assert op != null;
Expand Down Expand Up @@ -795,7 +832,7 @@ public Example[] examples() {
}

void renderDocsForOperators(String name, String titleName, Constructor<?> ctor, FunctionInfo info, boolean variadic)
throws IOException {
throws Exception {
renderKibanaInlineDocs(name, titleName, info);

var params = ctor.getParameters();
Expand Down Expand Up @@ -999,7 +1036,7 @@ void renderKibanaFunctionDefinition(
FunctionInfo info,
List<EsqlFunctionRegistry.ArgSignature> args,
boolean variadic
) throws IOException {
) throws Exception {

try (XContentBuilder builder = JsonXContent.contentBuilder().prettyPrint().lfAtEnd().startObject()) {
builder.field(
Expand All @@ -1019,6 +1056,10 @@ void renderKibanaFunctionDefinition(
});
}
builder.field("name", name);
License.OperationMode license = licenseChecker.invoke(null);
if (license != null && license != License.OperationMode.BASIC) {
builder.field("license", license.toString());
}
if (titleName != null && titleName.equals(name) == false) {
builder.field("titleName", titleName);
}
Expand Down Expand Up @@ -1073,6 +1114,10 @@ void renderKibanaFunctionDefinition(
builder.endObject();
}
builder.endArray();
license = licenseChecker.invoke(sig.getKey());
if (license != null && license != License.OperationMode.BASIC) {
builder.field("license", license.toString());
}
builder.field("variadic", variadic);
builder.field("returnType", sig.getValue().esNameIfPossible());
builder.endObject();
Expand Down
Loading