99
1010package org .elasticsearch .inference ;
1111
12+ import org .elasticsearch .TransportVersion ;
13+ import org .elasticsearch .TransportVersions ;
14+ import org .elasticsearch .cluster .Diff ;
15+ import org .elasticsearch .cluster .SimpleDiffable ;
16+ import org .elasticsearch .common .io .stream .StreamInput ;
17+ import org .elasticsearch .common .io .stream .StreamOutput ;
1218import org .elasticsearch .core .Nullable ;
1319import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
1420import org .elasticsearch .xcontent .ConstructingObjectParser ;
4652 * @param elementType the type of elements in the embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
4753 */
4854public record MinimalServiceSettings (
55+ @ Nullable String service ,
4956 TaskType taskType ,
5057 @ Nullable Integer dimensions ,
5158 @ Nullable SimilarityMeasure similarity ,
5259 @ Nullable ElementType elementType
53- ) implements ToXContentObject {
60+ ) implements ServiceSettings , SimpleDiffable < MinimalServiceSettings > {
5461
62+ public static final String NAME = "minimal_service_settings" ;
63+
64+ public static final String SERVICE_FIELD = "service" ;
5565 public static final String TASK_TYPE_FIELD = "task_type" ;
5666 static final String DIMENSIONS_FIELD = "dimensions" ;
5767 static final String SIMILARITY_FIELD = "similarity" ;
@@ -61,17 +71,20 @@ public record MinimalServiceSettings(
6171 "model_settings" ,
6272 true ,
6373 args -> {
64- TaskType taskType = TaskType .fromString ((String ) args [0 ]);
65- Integer dimensions = (Integer ) args [1 ];
66- SimilarityMeasure similarity = args [2 ] == null ? null : SimilarityMeasure .fromString ((String ) args [2 ]);
67- DenseVectorFieldMapper .ElementType elementType = args [3 ] == null
74+ String service = (String ) args [0 ];
75+ TaskType taskType = TaskType .fromString ((String ) args [1 ]);
76+ Integer dimensions = (Integer ) args [2 ];
77+ SimilarityMeasure similarity = args [3 ] == null ? null : SimilarityMeasure .fromString ((String ) args [3 ]);
78+ DenseVectorFieldMapper .ElementType elementType = args [4 ] == null
6879 ? null
69- : DenseVectorFieldMapper .ElementType .fromString ((String ) args [3 ]);
70- return new MinimalServiceSettings (taskType , dimensions , similarity , elementType );
80+ : DenseVectorFieldMapper .ElementType .fromString ((String ) args [4 ]);
81+ return new MinimalServiceSettings (service , taskType , dimensions , similarity , elementType );
7182 }
7283 );
84+ private static final String UNKNOWN_SERVICE = "_unknown_" ;
7385
7486 static {
87+ PARSER .declareString (ConstructingObjectParser .optionalConstructorArg (), new ParseField (SERVICE_FIELD ));
7588 PARSER .declareString (ConstructingObjectParser .constructorArg (), new ParseField (TASK_TYPE_FIELD ));
7689 PARSER .declareInt (ConstructingObjectParser .optionalConstructorArg (), new ParseField (DIMENSIONS_FIELD ));
7790 PARSER .declareString (ConstructingObjectParser .optionalConstructorArg (), new ParseField (SIMILARITY_FIELD ));
@@ -82,51 +95,95 @@ public static MinimalServiceSettings parse(XContentParser parser) throws IOExcep
8295 return PARSER .parse (parser , null );
8396 }
8497
85- public static MinimalServiceSettings textEmbedding (int dimensions , SimilarityMeasure similarity , ElementType elementType ) {
86- return new MinimalServiceSettings (TEXT_EMBEDDING , dimensions , similarity , elementType );
98+ public static MinimalServiceSettings textEmbedding (
99+ String serviceName ,
100+ int dimensions ,
101+ SimilarityMeasure similarity ,
102+ ElementType elementType
103+ ) {
104+ return new MinimalServiceSettings (serviceName , TEXT_EMBEDDING , dimensions , similarity , elementType );
105+ }
106+
107+ public static MinimalServiceSettings sparseEmbedding (String serviceName ) {
108+ return new MinimalServiceSettings (serviceName , SPARSE_EMBEDDING , null , null , null );
87109 }
88110
89- public static MinimalServiceSettings sparseEmbedding ( ) {
90- return new MinimalServiceSettings (SPARSE_EMBEDDING , null , null , null );
111+ public static MinimalServiceSettings rerank ( String serviceName ) {
112+ return new MinimalServiceSettings (serviceName , RERANK , null , null , null );
91113 }
92114
93- public static MinimalServiceSettings rerank ( ) {
94- return new MinimalServiceSettings (RERANK , null , null , null );
115+ public static MinimalServiceSettings completion ( String serviceName ) {
116+ return new MinimalServiceSettings (serviceName , COMPLETION , null , null , null );
95117 }
96118
97- public static MinimalServiceSettings completion ( ) {
98- return new MinimalServiceSettings (COMPLETION , null , null , null );
119+ public static MinimalServiceSettings chatCompletion ( String serviceName ) {
120+ return new MinimalServiceSettings (serviceName , CHAT_COMPLETION , null , null , null );
99121 }
100122
101- public static MinimalServiceSettings chatCompletion () {
102- return new MinimalServiceSettings (CHAT_COMPLETION , null , null , null );
123+ public MinimalServiceSettings {
124+ Objects .requireNonNull (taskType , "task type must not be null" );
125+ validate (taskType , dimensions , similarity , elementType );
103126 }
104127
105128 public MinimalServiceSettings (Model model ) {
106129 this (
130+ model .getConfigurations ().getService (),
107131 model .getTaskType (),
108132 model .getServiceSettings ().dimensions (),
109133 model .getServiceSettings ().similarity (),
110134 model .getServiceSettings ().elementType ()
111135 );
112136 }
113137
114- public MinimalServiceSettings (
115- TaskType taskType ,
116- @ Nullable Integer dimensions ,
117- @ Nullable SimilarityMeasure similarity ,
118- @ Nullable ElementType elementType
119- ) {
120- this .taskType = Objects .requireNonNull (taskType , "task type must not be null" );
121- this .dimensions = dimensions ;
122- this .similarity = similarity ;
123- this .elementType = elementType ;
124- validate ();
138+ public MinimalServiceSettings (StreamInput in ) throws IOException {
139+ this (
140+ in .readOptionalString (),
141+ TaskType .fromStream (in ),
142+ in .readOptionalInt (),
143+ in .readOptionalEnum (SimilarityMeasure .class ),
144+ in .readOptionalEnum (ElementType .class )
145+ );
146+ }
147+
148+ @ Override
149+ public void writeTo (StreamOutput out ) throws IOException {
150+ out .writeOptionalString (service );
151+ taskType .writeTo (out );
152+ out .writeOptionalInt (dimensions );
153+ out .writeOptionalEnum (similarity );
154+ out .writeOptionalEnum (elementType );
155+ }
156+
157+ @ Override
158+ public String getWriteableName () {
159+ return NAME ;
160+ }
161+
162+ @ Override
163+ public TransportVersion getMinimalSupportedVersion () {
164+ return TransportVersions .INFERENCE_MODEL_REGISTRY_METADATA ;
165+ }
166+
167+ @ Override
168+ public ToXContentObject getFilteredXContentObject () {
169+ return this ::toXContent ;
170+ }
171+
172+ @ Override
173+ public String modelId () {
174+ return null ;
175+ }
176+
177+ public static Diff <MinimalServiceSettings > readDiffFrom (StreamInput in ) throws IOException {
178+ return SimpleDiffable .readDiffFrom (MinimalServiceSettings ::new , in );
125179 }
126180
127181 @ Override
128182 public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
129183 builder .startObject ();
184+ if (service != null ) {
185+ builder .field (SERVICE_FIELD , service );
186+ }
130187 builder .field (TASK_TYPE_FIELD , taskType .toString ());
131188 if (dimensions != null ) {
132189 builder .field (DIMENSIONS_FIELD , dimensions );
@@ -143,7 +200,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
143200 @ Override
144201 public String toString () {
145202 final StringBuilder sb = new StringBuilder ();
146- sb .append ("task_type=" ).append (taskType );
203+ sb .append ("service=" ).append (service );
204+ sb .append (", task_type=" ).append (taskType );
147205 if (dimensions != null ) {
148206 sb .append (", dimensions=" ).append (dimensions );
149207 }
@@ -156,31 +214,46 @@ public String toString() {
156214 return sb .toString ();
157215 }
158216
159- private void validate () {
217+ private static void validate (TaskType taskType , Integer dimensions , SimilarityMeasure similarity , ElementType elementType ) {
160218 switch (taskType ) {
161219 case TEXT_EMBEDDING :
162- validateFieldPresent (DIMENSIONS_FIELD , dimensions );
163- validateFieldPresent (SIMILARITY_FIELD , similarity );
164- validateFieldPresent (ELEMENT_TYPE_FIELD , elementType );
220+ validateFieldPresent (DIMENSIONS_FIELD , dimensions , taskType );
221+ validateFieldPresent (SIMILARITY_FIELD , similarity , taskType );
222+ validateFieldPresent (ELEMENT_TYPE_FIELD , elementType , taskType );
165223 break ;
166224
167225 default :
168- validateFieldNotPresent (DIMENSIONS_FIELD , dimensions );
169- validateFieldNotPresent (SIMILARITY_FIELD , similarity );
170- validateFieldNotPresent (ELEMENT_TYPE_FIELD , elementType );
226+ validateFieldNotPresent (DIMENSIONS_FIELD , dimensions , taskType );
227+ validateFieldNotPresent (SIMILARITY_FIELD , similarity , taskType );
228+ validateFieldNotPresent (ELEMENT_TYPE_FIELD , elementType , taskType );
171229 break ;
172230 }
173231 }
174232
175- private void validateFieldPresent (String field , Object fieldValue ) {
233+ private static void validateFieldPresent (String field , Object fieldValue , TaskType taskType ) {
176234 if (fieldValue == null ) {
177235 throw new IllegalArgumentException ("required [" + field + "] field is missing for task_type [" + taskType .name () + "]" );
178236 }
179237 }
180238
181- private void validateFieldNotPresent (String field , Object fieldValue ) {
239+ private static void validateFieldNotPresent (String field , Object fieldValue , TaskType taskType ) {
182240 if (fieldValue != null ) {
183241 throw new IllegalArgumentException ("[" + field + "] is not allowed for task_type [" + taskType .name () + "]" );
184242 }
185243 }
244+
245+ public ModelConfigurations toModelConfigurations (String inferenceEntityId ) {
246+ return new ModelConfigurations (inferenceEntityId , taskType , service == null ? UNKNOWN_SERVICE : service , this );
247+ }
248+
249+ /**
250+ * Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
251+ */
252+ public boolean canMergeWith (MinimalServiceSettings other ) {
253+ return taskType == other .taskType
254+ && Objects .equals (dimensions , other .dimensions )
255+ && similarity == other .similarity
256+ && elementType == other .elementType
257+ && (service == null || service .equals (other .service ));
258+ }
186259}
0 commit comments