Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ public Integer getMaxInputTokens() {
return textEmbeddingSettings.maxInputTokens;
}

TextEmbeddingSettings getTextEmbeddingSettings() {
return textEmbeddingSettings;
}

public String getUrl() {
return url;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

String getCompletionResultPath() {
return completionResultPath;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

String getRelevanceScorePath() {
return relevanceScorePath;
}

String getRerankIndexPath() {
return rerankIndexPath;
}

String getDocumentTextPath() {
return documentTextPath;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

String getTokenPath() {
return tokenPath;
}

String getWeightPath() {
return weightPath;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ private static CommonFields commonFieldsFromMap(Map<String, Object> map, Validat
private final DenseVectorFieldMapper.ElementType elementType;

CustomElandInternalTextEmbeddingServiceSettings(
Integer numAllocations,
@Nullable Integer numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
@Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings,
@Nullable String deploymentId,
Integer dimensions,
@Nullable Integer dimensions,
SimilarityMeasure similarityMeasure,
DenseVectorFieldMapper.ElementType elementType
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap(
}

public ElasticsearchInternalServiceSettings(
Integer numAllocations,
@Nullable Integer numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
@Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings,
@Nullable String deploymentId
) {
this.numAllocations = numAllocations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ protected ModelSecrets createTestInstance() {

@Override
protected ModelSecrets mutateInstance(ModelSecrets instance) {
return randomValueOtherThan(instance, ModelSecretsTests::createRandomInstance);
return new ModelSecrets(randomValueOtherThan(instance.getSecretSettings(), ModelSecretsTests::randomSecretSettings));
}

public record FakeSecretSettings(String apiKey) implements SecretSettings {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ protected GetRerankerWindowSizeAction.Request createTestInstance() {

@Override
protected GetRerankerWindowSizeAction.Request mutateInstance(GetRerankerWindowSizeAction.Request instance) throws IOException {
return randomValueOtherThan(instance, this::createTestInstance);
return new GetRerankerWindowSizeAction.Request(randomValueOtherThan(instance.getInferenceEntityId(), () -> randomAlphaOfLength(8)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;

import java.io.IOException;
Expand All @@ -26,6 +27,6 @@ protected GetRerankerWindowSizeAction.Response createTestInstance() {

@Override
protected GetRerankerWindowSizeAction.Response mutateInstance(GetRerankerWindowSizeAction.Response instance) throws IOException {
return randomValueOtherThan(instance, this::createTestInstance);
return new GetRerankerWindowSizeAction.Response(randomValueOtherThan(instance.getWindowSize(), ESTestCase::randomNonNegativeInt));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
Expand Down Expand Up @@ -38,16 +39,18 @@ protected Writeable.Reader<InferenceAction.Response> instanceReader() {

@Override
protected InferenceAction.Response createTestInstance() {
var result = randomBoolean()
? DenseEmbeddingFloatResultsTests.createRandomResults()
: SparseEmbeddingResultsTests.createRandomResults();
return new InferenceAction.Response(getRandomResults());
}

return new InferenceAction.Response(result);
private InferenceServiceResults getRandomResults() {
return randomBoolean() ? DenseEmbeddingFloatResultsTests.createRandomResults() : SparseEmbeddingResultsTests.createRandomResults();
}

@Override
protected InferenceAction.Response mutateInstance(InferenceAction.Response instance) throws IOException {
return randomValueOtherThan(instance, this::createTestInstance);
var originalResults = instance.getResults();

return new InferenceAction.Response(randomValueOtherThan(originalResults, this::getRandomResults));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
package org.elasticsearch.xpack.inference.action;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
Expand Down Expand Up @@ -39,7 +42,20 @@ protected PutInferenceModelAction.Request createTestInstance() {

@Override
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
return randomValueOtherThan(instance, this::createTestInstance);
TaskType taskType = instance.getTaskType();
String inferenceId = instance.getInferenceEntityId();
BytesReference content = instance.getContent();
XContentType contentType = instance.getContentType();
TimeValue timeout = instance.getTimeout();
switch (randomInt(4)) {
case 0 -> taskType = randomValueOtherThan(taskType, () -> randomFrom(TaskType.values()));
case 1 -> inferenceId = randomValueOtherThan(inferenceId, () -> randomAlphaOfLength(6));
case 2 -> content = randomValueOtherThan(content, () -> randomBytesReference(50));
case 3 -> contentType = randomValueOtherThan(contentType, () -> randomFrom(XContentType.values()));
case 4 -> timeout = randomValueOtherThan(timeout, ESTestCase::randomTimeValue);
default -> throw new AssertionError("Illegal randomisation branch");
}
return new PutInferenceModelAction.Request(taskType, inferenceId, content, contentType, timeout);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.XPackClientPlugin;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
Expand All @@ -27,10 +28,8 @@ protected PutInferenceModelAction.Response createTestInstance() {

@Override
protected PutInferenceModelAction.Response mutateInstance(PutInferenceModelAction.Response instance) {
return randomValueOtherThan(instance, () -> {
var mutatedModel = ModelConfigurationsTests.mutateTestInstance(instance.getModel());
return new PutInferenceModelAction.Response(mutatedModel);
});
ModelConfigurations newModel = randomValueOtherThan(instance.getModel(), ModelConfigurationsTests::createRandomInstance);
return new PutInferenceModelAction.Response(newModel);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,19 @@ protected UpdateInferenceModelAction.Request createTestInstance() {

@Override
protected UpdateInferenceModelAction.Request mutateInstance(UpdateInferenceModelAction.Request instance) throws IOException {
return randomValueOtherThan(instance, this::createTestInstance);
var inferenceId = instance.getInferenceEntityId();
var content = instance.getContent();
var contentType = instance.getContentType();
var taskType = instance.getTaskType();
switch (randomInt(3)) {
case 0 -> inferenceId = randomValueOtherThan(inferenceId, () -> randomAlphaOfLength(5));
case 1 -> content = randomValueOtherThan(content, () -> randomBytesReference(50));
case 2 -> contentType = randomValueOtherThan(contentType, () -> randomFrom(XContentType.values()));
case 3 -> taskType = randomValueOtherThan(taskType, () -> randomFrom(TaskType.values()));
default -> throw new AssertionError("Illegal randomisation branch");
}

return new UpdateInferenceModelAction.Request(inferenceId, content, contentType, taskType, randomTimeValue());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.XPackClientPlugin;
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
Expand All @@ -32,7 +33,8 @@ protected UpdateInferenceModelAction.Response createTestInstance() {

@Override
protected UpdateInferenceModelAction.Response mutateInstance(UpdateInferenceModelAction.Response instance) throws IOException {
return randomValueOtherThan(instance, this::createTestInstance);
ModelConfigurations newModel = randomValueOtherThan(instance.getModel(), ModelConfigurationsTests::createRandomInstance);
return new UpdateInferenceModelAction.Response(newModel);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,19 @@ protected AwsSecretSettings createTestInstance() {

@Override
protected AwsSecretSettings mutateInstance(AwsSecretSettings instance) throws IOException {
return randomValueOtherThan(instance, AwsSecretSettingsTests::createRandom);
if (randomBoolean()) {
var accessKey = randomValueOtherThan(instance.accessKey().toString(), () -> randomAlphaOfLength(10));
return new AwsSecretSettings(new SecureString(accessKey.toCharArray()), instance.secretKey());
} else {
var secretKey = randomValueOtherThan(instance.secretKey().toString(), () -> randomAlphaOfLength(10));
return new AwsSecretSettings(instance.accessKey(), new SecureString(secretKey.toCharArray()));
}
}

private static AwsSecretSettings createRandom() {
return new AwsSecretSettings(new SecureString(randomAlphaOfLength(10)), new SecureString(randomAlphaOfLength(10)));
return new AwsSecretSettings(
new SecureString(randomAlphaOfLength(10).toCharArray()),
new SecureString(randomAlphaOfLength(10).toCharArray())
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,17 @@ protected Ai21ChatCompletionServiceSettings createTestInstance() {

@Override
protected Ai21ChatCompletionServiceSettings mutateInstance(Ai21ChatCompletionServiceSettings instance) throws IOException {
return randomValueOtherThan(instance, Ai21ChatCompletionServiceSettingsTests::createRandom);
if (randomBoolean()) {
return new Ai21ChatCompletionServiceSettings(
randomValueOtherThan(instance.modelId(), () -> randomAlphaOfLength(8)),
instance.rateLimitSettings()
);
} else {
return new Ai21ChatCompletionServiceSettings(
instance.modelId(),
randomValueOtherThan(instance.rateLimitSettings(), RateLimitSettingsTests::createRandom)
);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,18 @@ protected AmazonBedrockChatCompletionServiceSettings createTestInstance() {
@Override
protected AmazonBedrockChatCompletionServiceSettings mutateInstance(AmazonBedrockChatCompletionServiceSettings instance)
throws IOException {
return randomValueOtherThan(instance, AmazonBedrockChatCompletionServiceSettingsTests::createRandom);
var region = instance.region();
var modelId = instance.modelId();
var provider = instance.provider();
var rateLimitSettings = instance.rateLimitSettings();
switch (randomInt(3)) {
case 0 -> region = randomValueOtherThan(region, () -> randomAlphaOfLength(10));
case 1 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLength(10));
case 2 -> provider = randomValueOtherThan(provider, () -> randomFrom(AmazonBedrockProvider.values()));
case 3 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom);
default -> throw new AssertionError("Illegal randomisation branch");
}
return new AmazonBedrockChatCompletionServiceSettings(region, modelId, provider, rateLimitSettings);
}

private static AmazonBedrockChatCompletionServiceSettings createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
Expand Down Expand Up @@ -280,15 +281,27 @@ protected AmazonBedrockChatCompletionTaskSettings createTestInstance() {

@Override
protected AmazonBedrockChatCompletionTaskSettings mutateInstance(AmazonBedrockChatCompletionTaskSettings instance) throws IOException {
return randomValueOtherThan(instance, AmazonBedrockChatCompletionTaskSettingsTests::createRandom);
var temperature = instance.temperature();
var topP = instance.topP();
var topK = instance.topK();
var maxNewTokens = instance.maxNewTokens();
switch (randomInt(3)) {
case 0 -> temperature = randomValueOtherThan(temperature, ESTestCase::randomOptionalDouble);
case 1 -> topP = randomValueOtherThan(topP, ESTestCase::randomOptionalDouble);
case 2 -> topK = randomValueOtherThan(topK, ESTestCase::randomOptionalDouble);
case 3 -> maxNewTokens = randomValueOtherThan(maxNewTokens, ESTestCase::randomNonNegativeIntOrNull);
default -> throw new AssertionError("Illegal randomisation branch");
}
return new AmazonBedrockChatCompletionTaskSettings(temperature, topP, topK, maxNewTokens);
}

private static AmazonBedrockChatCompletionTaskSettings createRandom() {
return new AmazonBedrockChatCompletionTaskSettings(
randomFrom(new Double[] { null, randomDouble() }),
randomFrom(new Double[] { null, randomDouble() }),
randomFrom(new Double[] { null, randomDouble() }),
randomFrom(new Integer[] { null, randomNonNegativeInt() })
randomOptionalDouble(),
randomOptionalDouble(),
randomOptionalDouble(),
randomNonNegativeIntOrNull()
);
}

}
Loading