1313import org .elasticsearch .common .ValidationException ;
1414import org .elasticsearch .common .io .stream .StreamInput ;
1515import org .elasticsearch .common .io .stream .StreamOutput ;
16+ import org .elasticsearch .common .io .stream .Writeable ;
1617import org .elasticsearch .core .Nullable ;
1718import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
1819import org .elasticsearch .inference .ModelConfigurations ;
1920import org .elasticsearch .inference .ServiceSettings ;
2021import org .elasticsearch .inference .SimilarityMeasure ;
2122import org .elasticsearch .inference .TaskType ;
23+ import org .elasticsearch .xcontent .ToXContentFragment ;
2224import org .elasticsearch .xcontent .ToXContentObject ;
2325import org .elasticsearch .xcontent .XContentBuilder ;
2426import org .elasticsearch .xpack .inference .services .ConfigurationParseContext ;
@@ -66,9 +68,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
6668 public static CustomServiceSettings fromMap (Map <String , Object > map , ConfigurationParseContext context , TaskType taskType ) {
6769 ValidationException validationException = new ValidationException ();
6870
69- SimilarityMeasure similarity = extractSimilarity (map , ModelConfigurations .SERVICE_SETTINGS , validationException );
70- Integer dims = removeAsType (map , DIMENSIONS , Integer .class );
71- Integer maxInputTokens = removeAsType (map , MAX_INPUT_TOKENS , Integer .class );
71+ var textEmbeddingSettings = TextEmbeddingSettings .fromMap (map , taskType , validationException );
7272
7373 String url = extractRequiredString (map , URL , ModelConfigurations .SERVICE_SETTINGS , validationException );
7474
@@ -134,9 +134,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
134134 }
135135
136136 return new CustomServiceSettings (
137- similarity ,
138- dims ,
139- maxInputTokens ,
137+ textEmbeddingSettings ,
140138 url ,
141139 stringHeaders ,
142140 queryParams ,
@@ -147,9 +145,59 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
147145 );
148146 }
149147
150- private final SimilarityMeasure similarity ;
151- private final Integer dimensions ;
152- private final Integer maxInputTokens ;
148+ public record TextEmbeddingSettings (
149+ @ Nullable SimilarityMeasure similarityMeasure ,
150+ @ Nullable Integer dimensions ,
151+ @ Nullable Integer maxInputTokens ,
152+ @ Nullable DenseVectorFieldMapper .ElementType elementType
153+ ) implements ToXContentFragment , Writeable {
154+
155+ public static final TextEmbeddingSettings EMPTY = new TextEmbeddingSettings (null , null , null , null );
156+
157+ public static TextEmbeddingSettings fromMap (Map <String , Object > map , TaskType taskType , ValidationException validationException ) {
158+ if (taskType != TaskType .TEXT_EMBEDDING ) {
159+ return EMPTY ;
160+ }
161+
162+ SimilarityMeasure similarity = extractSimilarity (map , ModelConfigurations .SERVICE_SETTINGS , validationException );
163+ Integer dims = removeAsType (map , DIMENSIONS , Integer .class );
164+ Integer maxInputTokens = removeAsType (map , MAX_INPUT_TOKENS , Integer .class );
165+ return new TextEmbeddingSettings (similarity , dims , maxInputTokens , DenseVectorFieldMapper .ElementType .FLOAT );
166+ }
167+
168+ public TextEmbeddingSettings (StreamInput in ) throws IOException {
169+ this (
170+ in .readOptionalEnum (SimilarityMeasure .class ),
171+ in .readOptionalVInt (),
172+ in .readOptionalVInt (),
173+ in .readOptionalEnum (DenseVectorFieldMapper .ElementType .class )
174+ );
175+ }
176+
177+ @ Override
178+ public void writeTo (StreamOutput out ) throws IOException {
179+ out .writeOptionalEnum (similarityMeasure );
180+ out .writeOptionalVInt (dimensions );
181+ out .writeOptionalVInt (maxInputTokens );
182+ out .writeOptionalEnum (elementType );
183+ }
184+
185+ @ Override
186+ public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
187+ if (similarityMeasure != null ) {
188+ builder .field (SIMILARITY , similarityMeasure );
189+ }
190+ if (dimensions != null ) {
191+ builder .field (DIMENSIONS , dimensions );
192+ }
193+ if (maxInputTokens != null ) {
194+ builder .field (MAX_INPUT_TOKENS , maxInputTokens );
195+ }
196+ return builder ;
197+ }
198+ }
199+
200+ private final TextEmbeddingSettings textEmbeddingSettings ;
153201 private final String url ;
154202 private final Map <String , String > headers ;
155203 private final QueryParameters queryParameters ;
@@ -159,9 +207,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
159207 private final ErrorResponseParser errorParser ;
160208
161209 public CustomServiceSettings (
162- @ Nullable SimilarityMeasure similarity ,
163- @ Nullable Integer dimensions ,
164- @ Nullable Integer maxInputTokens ,
210+ @ Nullable TextEmbeddingSettings textEmbeddingSettings ,
165211 String url ,
166212 @ Nullable Map <String , String > headers ,
167213 @ Nullable QueryParameters queryParameters ,
@@ -170,9 +216,7 @@ public CustomServiceSettings(
170216 @ Nullable RateLimitSettings rateLimitSettings ,
171217 ErrorResponseParser errorParser
172218 ) {
173- this .similarity = similarity ;
174- this .dimensions = dimensions ;
175- this .maxInputTokens = maxInputTokens ;
219+ this .textEmbeddingSettings = textEmbeddingSettings == null ? TextEmbeddingSettings .EMPTY : textEmbeddingSettings ;
176220 this .url = Objects .requireNonNull (url );
177221 this .headers = Collections .unmodifiableMap (Objects .requireNonNullElse (headers , Map .of ()));
178222 this .queryParameters = Objects .requireNonNullElse (queryParameters , QueryParameters .EMPTY );
@@ -183,9 +227,7 @@ public CustomServiceSettings(
183227 }
184228
185229 public CustomServiceSettings (StreamInput in ) throws IOException {
186- similarity = in .readOptionalEnum (SimilarityMeasure .class );
187- dimensions = in .readOptionalVInt ();
188- maxInputTokens = in .readOptionalVInt ();
230+ textEmbeddingSettings = new TextEmbeddingSettings (in );
189231 url = in .readString ();
190232 headers = in .readImmutableMap (StreamInput ::readString );
191233 queryParameters = new QueryParameters (in );
@@ -203,21 +245,21 @@ public String modelId() {
203245
204246 @ Override
205247 public SimilarityMeasure similarity () {
206- return similarity ;
248+ return textEmbeddingSettings . similarityMeasure ;
207249 }
208250
209251 @ Override
210252 public Integer dimensions () {
211- return dimensions ;
253+ return textEmbeddingSettings . dimensions ;
212254 }
213255
214256 @ Override
215257 public DenseVectorFieldMapper .ElementType elementType () {
216- return DenseVectorFieldMapper . ElementType . FLOAT ;
258+ return textEmbeddingSettings . elementType ;
217259 }
218260
219261 public Integer getMaxInputTokens () {
220- return maxInputTokens ;
262+ return textEmbeddingSettings . maxInputTokens ;
221263 }
222264
223265 public String getUrl () {
@@ -270,15 +312,7 @@ public XContentBuilder toXContentFragment(XContentBuilder builder, Params params
270312
271313 @ Override
272314 public XContentBuilder toXContentFragmentOfExposedFields (XContentBuilder builder , Params params ) throws IOException {
273- if (similarity != null ) {
274- builder .field (SIMILARITY , similarity );
275- }
276- if (dimensions != null ) {
277- builder .field (DIMENSIONS , dimensions );
278- }
279- if (maxInputTokens != null ) {
280- builder .field (MAX_INPUT_TOKENS , maxInputTokens );
281- }
315+ textEmbeddingSettings .toXContent (builder , params );
282316 builder .field (URL , url );
283317
284318 if (headers .isEmpty () == false ) {
@@ -317,9 +351,7 @@ public TransportVersion getMinimalSupportedVersion() {
317351
318352 @ Override
319353 public void writeTo (StreamOutput out ) throws IOException {
320- out .writeOptionalEnum (similarity );
321- out .writeOptionalVInt (dimensions );
322- out .writeOptionalVInt (maxInputTokens );
354+ textEmbeddingSettings .writeTo (out );
323355 out .writeString (url );
324356 out .writeMap (headers , StreamOutput ::writeString , StreamOutput ::writeString );
325357 queryParameters .writeTo (out );
@@ -334,9 +366,7 @@ public boolean equals(Object o) {
334366 if (this == o ) return true ;
335367 if (o == null || getClass () != o .getClass ()) return false ;
336368 CustomServiceSettings that = (CustomServiceSettings ) o ;
337- return Objects .equals (similarity , that .similarity )
338- && Objects .equals (dimensions , that .dimensions )
339- && Objects .equals (maxInputTokens , that .maxInputTokens )
369+ return Objects .equals (textEmbeddingSettings , that .textEmbeddingSettings )
340370 && Objects .equals (url , that .url )
341371 && Objects .equals (headers , that .headers )
342372 && Objects .equals (queryParameters , that .queryParameters )
@@ -349,9 +379,7 @@ public boolean equals(Object o) {
349379 @ Override
350380 public int hashCode () {
351381 return Objects .hash (
352- similarity ,
353- dimensions ,
354- maxInputTokens ,
382+ textEmbeddingSettings ,
355383 url ,
356384 headers ,
357385 queryParameters ,
0 commit comments