@@ -60,13 +60,15 @@ public static class Request extends ActionRequest {
// Names of the request-body fields. Each constant below is registered with PARSER,
// which binds it to the matching Request.Builder setter.
6060 public static final ParseField INPUT = new ParseField ("input" );
6161 public static final ParseField TASK_SETTINGS = new ParseField ("task_settings" );
6262 public static final ParseField QUERY = new ParseField ("query" );
// Boolean field; the Builder keeps its initial value of false unless the setter is called.
63+ public static final ParseField CHUNKING_ENABLED = new ParseField ("chunking_enabled" );
6364 public static final ParseField TIMEOUT = new ParseField ("timeout" );
6465
/**
 * Parser for the request body. Every supported field is bound to the
 * {@link Request.Builder} setter that stores its value.
 */
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
static {
    // "input" is declared as an array of strings.
    PARSER.declareStringArray(Request.Builder::setInput, INPUT);
    // "task_settings" is an object read through XContentParser#mapOrdered().
    PARSER.declareObject(Request.Builder::setTaskSettings, (parser, context) -> parser.mapOrdered(), TASK_SETTINGS);
    PARSER.declareString(Request.Builder::setQuery, QUERY);
    PARSER.declareBoolean(Request.Builder::setChunkingEnabled, CHUNKING_ENABLED);
    // "timeout" is declared as a string and handed to the builder unchanged.
    PARSER.declareString(Request.Builder::setInferenceTimeout, TIMEOUT);
}
7274
@@ -93,6 +95,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
9395 private final InputType inputType ;
9496 private final TimeValue inferenceTimeout ;
9597 private final boolean stream ;
98+ private final boolean chunkingEnabled ;
9699
97100 public Request (
98101 TaskType taskType ,
@@ -112,6 +115,29 @@ public Request(
112115 this .inputType = inputType ;
113116 this .inferenceTimeout = inferenceTimeout ;
114117 this .stream = stream ;
118+ this .chunkingEnabled = false ;
119+ }
120+
/**
 * Creates a request in which every field, including the chunking flag, is supplied
 * by the caller. All arguments are stored as given; nothing is copied or validated here.
 *
 * @param taskType          task type of the request
 * @param inferenceEntityId id of the inference entity the request addresses
 * @param query             query string of the request
 * @param input             input strings of the request
 * @param taskSettings      task settings map
 * @param inputType         input type of the request
 * @param inferenceTimeout  timeout value stored on the request
 * @param stream            value later returned by {@code isStreaming()}
 * @param chunkingEnabled   value later returned by {@code isChunkingEnabled()}
 */
public Request(
    TaskType taskType,
    String inferenceEntityId,
    String query,
    List<String> input,
    Map<String, Object> taskSettings,
    InputType inputType,
    TimeValue inferenceTimeout,
    boolean stream,
    boolean chunkingEnabled
) {
    // Identification of the request.
    this.inferenceEntityId = inferenceEntityId;
    this.taskType = taskType;

    // Payload.
    this.input = input;
    this.query = query;
    this.inputType = inputType;
    this.taskSettings = taskSettings;

    // Execution options.
    this.inferenceTimeout = inferenceTimeout;
    this.chunkingEnabled = chunkingEnabled;
    this.stream = stream;
}
116142
117143 public Request (StreamInput in ) throws IOException {
@@ -138,6 +164,12 @@ public Request(StreamInput in) throws IOException {
138164 this .inferenceTimeout = DEFAULT_TIMEOUT ;
139165 }
140166
167+ if (in .getTransportVersion ().onOrAfter (TransportVersions .CHUNKING_ENABLED_PERFORM_INFERENCE )) {
168+ this .chunkingEnabled = in .readBoolean ();
169+ } else {
170+ this .chunkingEnabled = false ;
171+ }
172+
141173 // streaming is not supported yet for transport traffic
142174 this .stream = false ;
143175 }
@@ -174,6 +206,10 @@ public boolean isStreaming() {
174206 return stream ;
175207 }
176208
209+ public boolean isChunkingEnabled () {
210+ return chunkingEnabled ;
211+ }
212+
177213 @ Override
178214 public ActionRequestValidationException validate () {
179215 if (input == null ) {
@@ -201,6 +237,12 @@ public ActionRequestValidationException validate() {
201237 }
202238 }
203239
240+ if (chunkingEnabled && ((taskType .equals (TaskType .SPARSE_EMBEDDING ) || taskType .equals (TaskType .TEXT_EMBEDDING )) == false )) {
241+ var e = new ActionRequestValidationException ();
242+ e .addValidationError (format ("Chunking is only supported for embedding task types." ));
243+ return e ;
244+ }
245+
204246 return null ;
205247 }
206248
@@ -224,6 +266,10 @@ public void writeTo(StreamOutput out) throws IOException {
224266 out .writeOptionalString (query );
225267 out .writeTimeValue (inferenceTimeout );
226268 }
269+
270+ if (out .getTransportVersion ().onOrAfter (TransportVersions .CHUNKING_ENABLED_PERFORM_INFERENCE )) {
271+ out .writeBoolean (chunkingEnabled );
272+ }
227273 }
228274
229275 // default for easier testing
@@ -250,12 +296,13 @@ public boolean equals(Object o) {
250296 && Objects .equals (taskSettings , request .taskSettings )
251297 && Objects .equals (inputType , request .inputType )
252298 && Objects .equals (query , request .query )
253- && Objects .equals (inferenceTimeout , request .inferenceTimeout );
299+ && Objects .equals (inferenceTimeout , request .inferenceTimeout )
300+ && Objects .equals (chunkingEnabled , request .chunkingEnabled );
254301 }
255302
256303 @ Override
257304 public int hashCode () {
258- return Objects .hash (taskType , inferenceEntityId , input , taskSettings , inputType , query , inferenceTimeout );
305+ return Objects .hash (taskType , inferenceEntityId , input , taskSettings , inputType , query , chunkingEnabled , inferenceTimeout );
259306 }
260307
261308 public static class Builder {
@@ -266,6 +313,7 @@ public static class Builder {
266313 private InputType inputType = InputType .UNSPECIFIED ;
267314 private Map <String , Object > taskSettings = Map .of ();
268315 private String query ;
316+ private boolean chunkingEnabled = false ;
269317 private TimeValue timeout = DEFAULT_TIMEOUT ;
270318 private boolean stream = false ;
271319
@@ -291,6 +339,11 @@ public Builder setQuery(String query) {
291339 return this ;
292340 }
293341
342+ public Builder setChunkingEnabled (boolean chunkingEnabled ) {
343+ this .chunkingEnabled = chunkingEnabled ;
344+ return this ;
345+ }
346+
294347 public Builder setInputType (InputType inputType ) {
295348 this .inputType = inputType ;
296349 return this ;
@@ -316,7 +369,7 @@ public Builder setStream(boolean stream) {
316369 }
317370
318371 public Request build () {
319- return new Request (taskType , inferenceEntityId , query , input , taskSettings , inputType , timeout , stream );
372+ return new Request (taskType , inferenceEntityId , query , input , taskSettings , inputType , timeout , stream , chunkingEnabled );
320373 }
321374 }
322375
@@ -335,6 +388,8 @@ public String toString() {
335388 + this .getInputType ()
336389 + ", timeout="
337390 + this .getInferenceTimeout ()
391+ + ", chunking_enabled="
392+ + this .isChunkingEnabled ()
338393 + ")" ;
339394 }
340395 }
0 commit comments