Skip to content

Commit 2a9330b

Browse files
Adding more tests and cleaning up
1 parent 12d271e commit 2a9330b

File tree

8 files changed

+176
-76
lines changed

8 files changed

+176
-76
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public CustomModel(
4646
inferenceId,
4747
taskType,
4848
service,
49-
CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
49+
CustomServiceSettings.fromMap(serviceSettings, context, taskType),
5050
CustomTaskSettings.fromMap(taskSettings),
5151
CustomSecretSettings.fromMap(secrets)
5252
);
@@ -66,7 +66,7 @@ public CustomModel(
6666
inferenceId,
6767
taskType,
6868
service,
69-
CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
69+
CustomServiceSettings.fromMap(serviceSettings, context, taskType),
7070
CustomTaskSettings.fromMap(taskSettings),
7171
CustomSecretSettings.fromMap(secrets),
7272
chunkingSettings

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
333333
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
334334

335335
return new CustomServiceSettings(
336-
new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens(), null),
336+
new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens()),
337337
serviceSettings.getUrl(),
338338
serviceSettings.getHeaders(),
339339
serviceSettings.getQueryParameters(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
6666
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
6767
private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10;
6868

69-
public static CustomServiceSettings fromMap(
70-
Map<String, Object> map,
71-
ConfigurationParseContext context,
72-
TaskType taskType,
73-
String inferenceId
74-
) {
69+
public static CustomServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context, TaskType taskType) {
7570
ValidationException validationException = new ValidationException();
7671

7772
var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException);
@@ -140,14 +135,9 @@ public static CustomServiceSettings fromMap(
140135
public static class TextEmbeddingSettings implements ToXContentFragment, Writeable {
141136

142137
// This specifies float for the element type but null for all other settings
143-
public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(
144-
null,
145-
null,
146-
null,
147-
CustomServiceEmbeddingType.FLOAT
148-
);
138+
public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(null, null, null);
149139
// This refers to settings that are not related to the text embedding task type (all the settings should be null)
150-
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null);
140+
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null);
151141

152142
public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType taskType, ValidationException validationException) {
153143
if (taskType != TaskType.TEXT_EMBEDDING) {
@@ -157,7 +147,7 @@ public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType ta
157147
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
158148
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
159149
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
160-
return new TextEmbeddingSettings(similarity, dims, maxInputTokens, null);
150+
return new TextEmbeddingSettings(similarity, dims, maxInputTokens);
161151
}
162152

163153
private final SimilarityMeasure similarityMeasure;
@@ -167,8 +157,7 @@ public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType ta
167157
public TextEmbeddingSettings(
168158
@Nullable SimilarityMeasure similarityMeasure,
169159
@Nullable Integer dimensions,
170-
@Nullable Integer maxInputTokens,
171-
@Nullable CustomServiceEmbeddingType embeddingType
160+
@Nullable Integer maxInputTokens
172161
) {
173162
this.similarityMeasure = similarityMeasure;
174163
this.dimensions = dimensions;
@@ -331,7 +320,12 @@ public Integer dimensions() {
331320

332321
@Override
333322
public DenseVectorFieldMapper.ElementType elementType() {
334-
return responseJsonParser.getEmbeddingType().toElementType();
323+
var embeddingType = responseJsonParser.getEmbeddingType();
324+
if (embeddingType != null) {
325+
return embeddingType.toElementType();
326+
}
327+
328+
return null;
335329
}
336330

337331
public Integer getMaxInputTokens() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
106106
return builder;
107107
}
108108

109-
// For testing
109+
// Default for testing
110110
String getTextEmbeddingsPath() {
111111
return textEmbeddingsPath;
112112
}
113113

114+
@Override
114115
public CustomServiceEmbeddingType getEmbeddingType() {
115116
return embeddingType;
116117
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r
110110
String requestContentString = "\"input\":\"${input}\"";
111111

112112
CustomServiceSettings serviceSettings = new CustomServiceSettings(
113-
new CustomServiceSettings.TextEmbeddingSettings(
114-
SimilarityMeasure.DOT_PRODUCT,
115-
dims,
116-
maxInputTokens,
117-
CustomServiceEmbeddingType.FLOAT
118-
),
113+
new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens),
119114
url,
120115
headers,
121116
QueryParameters.EMPTY,

0 commit comments

Comments
 (0)