Skip to content

Commit 98aee7d

Browse files
Even more tests
1 parent 3961e74 commit 98aee7d

File tree

16 files changed

+490
-81
lines changed

16 files changed

+490
-81
lines changed
Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.registry;
8+
package org.elasticsearch.xpack.core.inference.action;
99

1010
import org.elasticsearch.action.ActionResponse;
1111
import org.elasticsearch.action.ActionType;
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.core.TimeValue;
1616
import org.elasticsearch.inference.Model;
17+
import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse;
1718

1819
import java.io.IOException;
1920
import java.util.List;
@@ -41,9 +42,27 @@ public Request(StreamInput in) throws IOException {
4142
models = in.readCollectionAsImmutableList(Model::new);
4243
}
4344

45+
@Override
46+
public void writeTo(StreamOutput out) throws IOException {
47+
super.writeTo(out);
48+
out.writeCollection(models);
49+
}
50+
4451
public List<Model> getModels() {
4552
return models;
4653
}
54+
55+
@Override
56+
public boolean equals(Object o) {
57+
if (o == null || getClass() != o.getClass()) return false;
58+
Request request = (Request) o;
59+
return Objects.equals(models, request.models);
60+
}
61+
62+
@Override
63+
public int hashCode() {
64+
return Objects.hashCode(models);
65+
}
4766
}
4867

4968
public static class Response extends ActionResponse {
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.core.Nullable;
14+
import org.elasticsearch.rest.RestStatus;
15+
16+
import java.io.IOException;
17+
import java.util.Objects;
18+
19+
/**
20+
* Response for storing a model in the model registry using the bulk API.
21+
*/
22+
public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) implements Writeable {
23+
24+
public ModelStoreResponse(StreamInput in) throws IOException {
25+
this(in.readString(), RestStatus.readFrom(in), in.readException());
26+
}
27+
28+
public boolean failed() {
29+
return failureCause != null;
30+
}
31+
32+
@Override
33+
public void writeTo(StreamOutput out) throws IOException {
34+
out.writeString(inferenceId);
35+
RestStatus.writeTo(out, status);
36+
out.writeException(failureCause);
37+
}
38+
39+
@Override
40+
public boolean equals(Object o) {
41+
if (o == null || getClass() != o.getClass()) return false;
42+
ModelStoreResponse that = (ModelStoreResponse) o;
43+
return status == that.status && Objects.equals(inferenceId, that.inferenceId)
44+
// Exception does not have hashCode() or equals() so assume errors are equal iff class and message are equal
45+
&& Objects.equals(
46+
failureCause == null ? null : failureCause.getMessage(),
47+
that.failureCause == null ? null : that.failureCause.getMessage()
48+
)
49+
&& Objects.equals(
50+
failureCause == null ? null : failureCause.getClass(),
51+
that.failureCause == null ? null : that.failureCause.getClass()
52+
);
53+
}
54+
55+
@Override
56+
public int hashCode() {
57+
return Objects.hash(
58+
inferenceId,
59+
status,
60+
// Exception does not have hashCode() or equals() so assume errors are equal iff class and message are equal
61+
failureCause == null ? null : failureCause.getMessage(),
62+
failureCause == null ? null : failureCause.getClass()
63+
);
64+
}
65+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
12+
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
14+
import org.elasticsearch.core.Nullable;
15+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
16+
import org.elasticsearch.inference.EmptySecretSettings;
17+
import org.elasticsearch.inference.EmptyTaskSettings;
18+
import org.elasticsearch.inference.Model;
19+
import org.elasticsearch.inference.ModelConfigurations;
20+
import org.elasticsearch.inference.ModelSecrets;
21+
import org.elasticsearch.inference.ServiceSettings;
22+
import org.elasticsearch.inference.SimilarityMeasure;
23+
import org.elasticsearch.inference.TaskType;
24+
import org.elasticsearch.test.ESTestCase;
25+
import org.elasticsearch.xcontent.ToXContentObject;
26+
import org.elasticsearch.xcontent.XContentBuilder;
27+
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests;
28+
29+
import java.io.IOException;
30+
import java.util.List;
31+
32+
public class ModelTests extends ESTestCase {
33+
public static Model randomModel() {
34+
return new Model(
35+
new ModelConfigurations(
36+
randomAlphaOfLength(6),
37+
randomFrom(TaskType.values()),
38+
randomAlphaOfLength(6),
39+
new TestServiceSettings(
40+
randomAlphaOfLength(10),
41+
randomIntBetween(1, 1024),
42+
randomFrom(SimilarityMeasure.values()),
43+
randomFrom(DenseVectorFieldMapper.ElementType.values())
44+
),
45+
EmptyTaskSettings.INSTANCE,
46+
randomBoolean() ? ChunkingSettingsTests.createRandomChunkingSettings() : null
47+
),
48+
new ModelSecrets(EmptySecretSettings.INSTANCE)
49+
);
50+
}
51+
52+
public record TestServiceSettings(
53+
String model,
54+
Integer dimensions,
55+
@Nullable SimilarityMeasure similarity,
56+
@Nullable DenseVectorFieldMapper.ElementType elementType
57+
) implements ServiceSettings {
58+
59+
static final String NAME = "test_text_embedding_service_settings";
60+
61+
public TestServiceSettings(StreamInput in) throws IOException {
62+
this(
63+
in.readString(),
64+
in.readInt(),
65+
in.readOptionalEnum(SimilarityMeasure.class),
66+
in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class)
67+
);
68+
}
69+
70+
@Override
71+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
72+
builder.startObject();
73+
builder.field("model", model);
74+
builder.field("dimensions", dimensions);
75+
if (similarity != null) {
76+
builder.field("similarity", similarity);
77+
}
78+
if (elementType != null) {
79+
builder.field("element_type", elementType);
80+
}
81+
builder.endObject();
82+
return builder;
83+
}
84+
85+
@Override
86+
public String getWriteableName() {
87+
return NAME;
88+
}
89+
90+
@Override
91+
public TransportVersion getMinimalSupportedVersion() {
92+
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
93+
}
94+
95+
@Override
96+
public void writeTo(StreamOutput out) throws IOException {
97+
out.writeString(model);
98+
out.writeInt(dimensions);
99+
out.writeOptionalEnum(similarity);
100+
out.writeOptionalEnum(elementType);
101+
}
102+
103+
@Override
104+
public ToXContentObject getFilteredXContentObject() {
105+
return this;
106+
}
107+
108+
@Override
109+
public SimilarityMeasure similarity() {
110+
return similarity != null ? similarity : SimilarityMeasure.COSINE;
111+
}
112+
113+
@Override
114+
public DenseVectorFieldMapper.ElementType elementType() {
115+
return elementType != null ? elementType : DenseVectorFieldMapper.ElementType.FLOAT;
116+
}
117+
118+
@Override
119+
public String modelId() {
120+
return model;
121+
}
122+
}
123+
124+
public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
125+
return List.of(
126+
new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new)
127+
);
128+
}
129+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.inference.EmptySecretSettings;
14+
import org.elasticsearch.inference.EmptyTaskSettings;
15+
import org.elasticsearch.inference.SecretSettings;
16+
import org.elasticsearch.inference.TaskSettings;
17+
import org.elasticsearch.xpack.core.XPackClientPlugin;
18+
import org.elasticsearch.xpack.core.inference.ModelTests;
19+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
20+
21+
import java.io.IOException;
22+
import java.util.ArrayList;
23+
24+
public class StoreInferenceEndpointsActionRequestTests extends AbstractBWCWireSerializationTestCase<StoreInferenceEndpointsAction.Request> {
25+
26+
@Override
27+
protected StoreInferenceEndpointsAction.Request mutateInstanceForVersion(
28+
StoreInferenceEndpointsAction.Request instance,
29+
TransportVersion version
30+
) {
31+
return instance;
32+
}
33+
34+
@Override
35+
protected Writeable.Reader<StoreInferenceEndpointsAction.Request> instanceReader() {
36+
return StoreInferenceEndpointsAction.Request::new;
37+
}
38+
39+
@Override
40+
protected StoreInferenceEndpointsAction.Request createTestInstance() {
41+
return new StoreInferenceEndpointsAction.Request(randomList(5, ModelTests::randomModel), randomTimeValue());
42+
}
43+
44+
@Override
45+
protected StoreInferenceEndpointsAction.Request mutateInstance(StoreInferenceEndpointsAction.Request instance) throws IOException {
46+
var newModels = new ArrayList<>(instance.getModels());
47+
newModels.add(ModelTests.randomModel());
48+
return new StoreInferenceEndpointsAction.Request(newModels, instance.masterNodeTimeout());
49+
}
50+
51+
@Override
52+
protected NamedWriteableRegistry getNamedWriteableRegistry() {
53+
var namedWriteables = new ArrayList<NamedWriteableRegistry.Entry>();
54+
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));
55+
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new));
56+
namedWriteables.addAll(ModelTests.getNamedWriteables());
57+
namedWriteables.addAll(XPackClientPlugin.getChunkingSettingsNamedWriteables());
58+
59+
return new NamedWriteableRegistry(namedWriteables);
60+
}
61+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.io.stream.Writeable;
12+
import org.elasticsearch.xpack.core.inference.results.ModelStoreResponseTests;
13+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
14+
15+
import java.io.IOException;
16+
import java.util.ArrayList;
17+
18+
public class StoreInferenceEndpointsActionResponseTests extends AbstractBWCWireSerializationTestCase<
19+
StoreInferenceEndpointsAction.Response> {
20+
21+
@Override
22+
protected StoreInferenceEndpointsAction.Response mutateInstanceForVersion(
23+
StoreInferenceEndpointsAction.Response instance,
24+
TransportVersion version
25+
) {
26+
return instance;
27+
}
28+
29+
@Override
30+
protected Writeable.Reader<StoreInferenceEndpointsAction.Response> instanceReader() {
31+
return StoreInferenceEndpointsAction.Response::new;
32+
}
33+
34+
@Override
35+
protected StoreInferenceEndpointsAction.Response createTestInstance() {
36+
return new StoreInferenceEndpointsAction.Response(randomList(5, ModelStoreResponseTests::randomModelStoreResponse));
37+
}
38+
39+
@Override
40+
protected StoreInferenceEndpointsAction.Response mutateInstance(StoreInferenceEndpointsAction.Response instance) throws IOException {
41+
var newResults = new ArrayList<>(instance.getResults());
42+
newResults.add(ModelStoreResponseTests.randomModelStoreResponse());
43+
return new StoreInferenceEndpointsAction.Response(newResults);
44+
}
45+
}

0 commit comments

Comments
 (0)