Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public enum FeatureFlag {
Version.fromString("9.2.0"),
null
),
RANDOM_SAMPLING("es.random_sampling_feature_flag_enabled=true", Version.fromString("9.2.0"), null);
RANDOM_SAMPLING("es.random_sampling_feature_flag_enabled=true", Version.fromString("9.2.0"), null),
INFERENCE_API_CCM("es.inference_api_ccm_feature_flag_enabled=true", Version.fromString("9.3.0"), null);

public final String systemProperty;
public final Version from;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.DocWriteResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMIndex;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMModel;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMStorageService;
import org.junit.Before;

import java.util.Collection;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;

/**
 * Single-node integration tests for {@link CCMStorageService}.
 * <p>
 * The tests exercise the store/get round trip as well as the errors returned by {@code get} when the
 * CCM index is absent, when the index exists but holds no document under
 * {@link CCMStorageService#CCM_DOC_ID}, and when that document lacks the required fields.
 */
public class CCMStorageServiceIT extends ESSingleNodeTestCase {
    private CCMStorageService ccmStorageService;

    /** Looks up the {@link CCMStorageService} instance registered with the test node before each test. */
    @Before
    public void createComponents() {
        ccmStorageService = node().injector().getInstance(CCMStorageService.class);
    }

    @Override
    protected Collection<Class<? extends Plugin>> getPlugins() {
        return pluginList(LocalStateInferencePlugin.class);
    }

    /** A model that was stored must be returned unchanged by a subsequent get. */
    public void testStoreAndGetCCMModel() {
        CCMModel expectedModel = new CCMModel(new SecureString("secret".toCharArray()));

        PlainActionFuture<Void> storeFuture = new PlainActionFuture<>();
        ccmStorageService.store(expectedModel, storeFuture);
        assertNull(storeFuture.actionGet(TimeValue.THIRTY_SECONDS));

        CCMModel retrievedModel = requestStoredModel().actionGet(TimeValue.THIRTY_SECONDS);
        assertThat(retrievedModel, is(expectedModel));
    }

    /** Nothing has been indexed in this test, so the get is issued before any CCM document was written. */
    public void testGet_ThrowsResourceNotFoundException_WhenCCMIndexDoesNotExist() {
        assertGetFailsWithConfigurationNotFound();
    }

    /** A document is indexed under an unrelated id, so no document exists under the CCM document id. */
    public void testGet_ThrowsResourceNotFoundException_WhenCCMConfigurationDocumentDoesNotExist() {
        storeCorruptCCMModel("id");

        assertGetFailsWithConfigurationNotFound();
    }

    /** An empty document under the CCM document id must surface a parse failure naming the missing field. */
    public void testGetCCMModel_ThrowsException_WhenStoredModelIsCorrupted() {
        storeCorruptCCMModel(CCMStorageService.CCM_DOC_ID);

        PlainActionFuture<CCMModel> pendingGet = requestStoredModel();
        ElasticsearchException failure = expectThrows(
            ElasticsearchException.class,
            () -> pendingGet.actionGet(TimeValue.THIRTY_SECONDS)
        );

        assertThat(failure.getMessage(), containsString("Failed to retrieve CCM configuration"));
        assertThat(failure.getCause().getMessage(), containsString("Required [api_key]"));
    }

    /** Issues a get against the storage service and returns the future that receives its outcome. */
    private PlainActionFuture<CCMModel> requestStoredModel() {
        PlainActionFuture<CCMModel> pendingGet = new PlainActionFuture<>();
        ccmStorageService.get(pendingGet);
        return pendingGet;
    }

    /** Issues a get and asserts that it fails with the "not found" error. */
    private void assertGetFailsWithConfigurationNotFound() {
        PlainActionFuture<CCMModel> pendingGet = requestStoredModel();
        ResourceNotFoundException failure = expectThrows(
            ResourceNotFoundException.class,
            () -> pendingGet.actionGet(TimeValue.THIRTY_SECONDS)
        );

        assertThat(failure.getMessage(), is("CCM configuration not found"));
    }

    /**
     * Indexes an empty JSON object (no {@code api_key} field) into the CCM index under the given id
     * and refreshes immediately so the document is visible to the following get.
     *
     * @param id the document id to index the empty object under
     */
    private void storeCorruptCCMModel(String id) {
        String emptyDocument = """
            {

            }
            """;

        var indexRequest = client().prepareIndex()
            .setIndex(CCMIndex.INDEX_NAME)
            .setId(id)
            .setSource(emptyDocument, XContentType.JSON)
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

        var indexResponse = indexRequest.execute().actionGet(TimeValue.THIRTY_SECONDS);

        assertThat(indexResponse.getResult(), is(DocWriteResponse.Result.CREATED));
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureFlag;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMIndex;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMStorageService;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
Expand All @@ -160,6 +163,7 @@
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
Expand Down Expand Up @@ -403,6 +407,10 @@ public Collection<?> createComponents(PluginServices services) {
)
);

if (CCMFeatureFlag.FEATURE_FLAG.isEnabled()) {
components.add(new CCMStorageService(services.client()));
}

return components;
}

Expand Down Expand Up @@ -491,6 +499,20 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {

@Override
public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings settings) {
List<SystemIndexDescriptor> ccmIndexDescriptor = CCMFeatureFlag.FEATURE_FLAG.isEnabled()
? List.of(
SystemIndexDescriptor.builder()
.setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED)
.setIndexPattern(CCMIndex.INDEX_PATTERN)
.setPrimaryIndex(CCMIndex.INDEX_NAME)
.setDescription("Contains Elastic Inference Service Cloud Connected Mode settings")
.setMappings(CCMIndex.mappings())
.setSettings(CCMIndex.settings())
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
.setNetNew()
.build()
)
: List.of();

var inferenceIndexV1Descriptor = SystemIndexDescriptor.builder()
.setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED)
Expand All @@ -503,29 +525,32 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
.build();

return List.of(
SystemIndexDescriptor.builder()
.setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED)
.setIndexPattern(InferenceIndex.INDEX_PATTERN)
.setAliasName(InferenceIndex.INDEX_ALIAS)
.setPrimaryIndex(InferenceIndex.INDEX_NAME)
.setDescription("Contains inference service and model configuration")
.setMappings(InferenceIndex.mappings())
.setSettings(getIndexSettings())
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
.setPriorSystemIndexDescriptors(List.of(inferenceIndexV1Descriptor))
.build(),
SystemIndexDescriptor.builder()
.setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED)
.setIndexPattern(InferenceSecretsIndex.INDEX_PATTERN)
.setPrimaryIndex(InferenceSecretsIndex.INDEX_NAME)
.setDescription("Contains inference service secrets")
.setMappings(InferenceSecretsIndex.mappings())
.setSettings(getSecretsIndexSettings())
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
.setNetNew()
.build()
);
return Stream.of(
List.of(
SystemIndexDescriptor.builder()
.setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED)
.setIndexPattern(InferenceIndex.INDEX_PATTERN)
.setAliasName(InferenceIndex.INDEX_ALIAS)
.setPrimaryIndex(InferenceIndex.INDEX_NAME)
.setDescription("Contains inference service and model configuration")
.setMappings(InferenceIndex.mappings())
.setSettings(getIndexSettings())
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
.setPriorSystemIndexDescriptors(List.of(inferenceIndexV1Descriptor))
.build(),
SystemIndexDescriptor.builder()
.setType(SystemIndexDescriptor.Type.INTERNAL_MANAGED)
.setIndexPattern(InferenceSecretsIndex.INDEX_PATTERN)
.setPrimaryIndex(InferenceSecretsIndex.INDEX_NAME)
.setDescription("Contains inference service secrets")
.setMappings(InferenceSecretsIndex.mappings())
.setSettings(getSecretsIndexSettings())
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
.setNetNew()
.build()
),
ccmIndexDescriptor
).flatMap(List::stream).toList();
}

// Overridable for tests
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.elastic.ccm;

import org.elasticsearch.common.util.FeatureFlag;

/**
 * Holder for the feature flag that gates the Elastic Inference Service Cloud Connected Mode (CCM)
 * functionality. Not instantiable.
 */
public final class CCMFeatureFlag {

    /**
     * Cloud Connected Mode (CCM) feature flag. When the feature is complete, this flag will be removed.
     * Enable feature via JVM option: `-Des.inference_api_ccm_feature_flag_enabled=true`.
     */
    public static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_api_ccm");

    private CCMFeatureFlag() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.elastic.ccm;

import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.io.UncheckedIOException;

import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME;
import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;

public class CCMIndex {

private CCMIndex() {}

public static final String INDEX_NAME = ".ccm-inference";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also start with an alias, like what we do for InferenceIndex, in case we ever need to migrate to .ccm-inference-00002 or whatever

public static final String INDEX_PATTERN = INDEX_NAME + "*";

// Increment this version number when the mappings change
private static final int INDEX_MAPPING_VERSION = 1;

/**
 * Returns the index settings for the CCM index, built from the defaults supplied by {@link #builder()}.
 */
public static Settings settings() {
    Settings.Builder defaults = builder();
    return defaults.build();
}

/**
 * Creates a settings builder pre-populated with the CCM index defaults: a single shard and
 * auto-expanding replicas in the range {@code 0-1}.
 * <p>
 * Public to allow tests to create the index with custom settings.
 */
public static Settings.Builder builder() {
    Settings.Builder indexSettings = Settings.builder();
    indexSettings.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1);
    indexSettings.put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1");
    return indexSettings;
}

/**
 * Builds the mappings for the CCM index: a strict (non-dynamic) mapping with a single {@code api_key}
 * keyword field, tagged with the current mapping version under {@code _meta}.
 *
 * @return a JSON builder containing the complete mappings object
 * @throws UncheckedIOException if the JSON builder fails while writing the mappings
 */
public static XContentBuilder mappings() {
    try {
        XContentBuilder mapping = jsonBuilder();

        mapping.startObject();
        mapping.startObject(SINGLE_MAPPING_NAME);

        // Mapping version metadata
        mapping.startObject("_meta");
        mapping.field(SystemIndexDescriptor.VERSION_META_KEY, INDEX_MAPPING_VERSION);
        mapping.endObject();

        // Reject documents containing fields that are not declared below
        mapping.field("dynamic", "strict");

        mapping.startObject("properties");
        mapping.startObject("api_key");
        mapping.field("type", "keyword");
        mapping.endObject();
        mapping.endObject();

        mapping.endObject();
        mapping.endObject();

        return mapping;
    } catch (IOException e) {
        throw new UncheckedIOException("Failed to build mappings for index " + INDEX_NAME, e);
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.elastic.ccm;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

/**
 * Cloud Connected Mode (CCM) configuration consisting of a single API key.
 * <p>
 * The model can be written to and read from the transport stream ({@link Writeable}) and rendered as
 * or parsed from a JSON object with one {@code api_key} string field.
 * <p>
 * {@link #toString()} is overridden so that the API key is never included in the string
 * representation; the implicit record implementation would print the secret in clear text.
 *
 * @param apiKey the API key; must not be {@code null}
 */
public record CCMModel(SecureString apiKey) implements Writeable, ToXContentObject {

    private static final String API_KEY_FIELD = "api_key";

    // NOTE(review): the boolean argument is presumably ConstructingObjectParser's ignoreUnknownFields
    // flag, which would make parsing lenient towards extra fields - confirm against the parser docs.
    private static final ConstructingObjectParser<CCMModel, Void> PARSER = new ConstructingObjectParser<>(
        CCMModel.class.getSimpleName(),
        true,
        args -> new CCMModel(new SecureString(((String) args[0]).toCharArray()))
    );

    static {
        // api_key is a required constructor argument
        PARSER.declareString(constructorArg(), new ParseField(API_KEY_FIELD));
    }

    /**
     * Parses a model from the given parser.
     *
     * @param parser the parser positioned on the object containing the {@code api_key} field
     * @return the parsed model
     * @throws IOException if reading from the parser fails
     */
    public static CCMModel parse(org.elasticsearch.xcontent.XContentParser parser) throws IOException {
        return PARSER.parse(parser, null);
    }

    /**
     * Validates that the API key is present.
     *
     * @throws NullPointerException if {@code apiKey} is {@code null}
     */
    public CCMModel {
        Objects.requireNonNull(apiKey);
    }

    /**
     * Reads a model from the transport stream.
     *
     * @param in the stream to read the API key from
     * @throws IOException if reading from the stream fails
     */
    public CCMModel(StreamInput in) throws IOException {
        this(in.readSecureString());
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeSecureString(apiKey);
    }

    /**
     * Writes the model as an object with a single {@code api_key} field holding the key in clear text.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(API_KEY_FIELD, apiKey.toString());
        builder.endObject();
        return builder;
    }

    /**
     * Parses a model from uncompressed JSON bytes.
     *
     * @param bytes the JSON document to parse
     * @return the parsed model
     * @throws IOException if the bytes cannot be read
     */
    public static CCMModel fromXContentBytes(BytesReference bytes) throws IOException {
        try (var parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, bytes, XContentType.JSON)) {
            return parse(parser);
        }
    }

    /**
     * Returns a representation that deliberately omits the API key so the secret cannot end up in
     * logs, exception messages, or assertion output.
     */
    @Override
    public String toString() {
        return "CCMModel[apiKey=<redacted>]";
    }
}
Loading