@@ -60,6 +60,8 @@ public static class Request extends BaseInferenceActionRequest {
6060 public static final ParseField INPUT_TYPE = new ParseField ("input_type" );
6161 public static final ParseField TASK_SETTINGS = new ParseField ("task_settings" );
6262 public static final ParseField QUERY = new ParseField ("query" );
63+ public static final ParseField RETURN_DOCUMENTS = new ParseField ("return_documents" );
64+ public static final ParseField TOP_N = new ParseField ("top_n" );
6365 public static final ParseField TIMEOUT = new ParseField ("timeout" );
6466
6567 static final ObjectParser <Request .Builder , Void > PARSER = new ObjectParser <>(NAME , Request .Builder ::new );
@@ -68,6 +70,8 @@ public static class Request extends BaseInferenceActionRequest {
6870 PARSER .declareString (Request .Builder ::setInputType , INPUT_TYPE );
6971 PARSER .declareObject (Request .Builder ::setTaskSettings , (p , c ) -> p .mapOrdered (), TASK_SETTINGS );
7072 PARSER .declareString (Request .Builder ::setQuery , QUERY );
73+ PARSER .declareBoolean (Request .Builder ::setReturnDocuments , RETURN_DOCUMENTS );
74+ PARSER .declareInt (Request .Builder ::setTopN , TOP_N );
7175 PARSER .declareString (Builder ::setInferenceTimeout , TIMEOUT );
7276 }
7377
@@ -89,6 +93,8 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
8993 private final TaskType taskType ;
9094 private final String inferenceEntityId ;
9195 private final String query ;
96+ private final Boolean returnDocuments ;
97+ private final Integer topN ;
9298 private final List <String > input ;
9399 private final Map <String , Object > taskSettings ;
94100 private final InputType inputType ;
@@ -99,6 +105,8 @@ public Request(
99105 TaskType taskType ,
100106 String inferenceEntityId ,
101107 String query ,
108+ Boolean returnDocuments ,
109+ Integer topN ,
102110 List <String > input ,
103111 Map <String , Object > taskSettings ,
104112 InputType inputType ,
@@ -109,6 +117,8 @@ public Request(
109117 taskType ,
110118 inferenceEntityId ,
111119 query ,
120+ returnDocuments ,
121+ topN ,
112122 input ,
113123 taskSettings ,
114124 inputType ,
@@ -122,6 +132,8 @@ public Request(
122132 TaskType taskType ,
123133 String inferenceEntityId ,
124134 String query ,
135+ Boolean returnDocuments ,
136+ Integer topN ,
125137 List <String > input ,
126138 Map <String , Object > taskSettings ,
127139 InputType inputType ,
@@ -133,6 +145,8 @@ public Request(
133145 this .taskType = taskType ;
134146 this .inferenceEntityId = inferenceEntityId ;
135147 this .query = query ;
148+ this .returnDocuments = returnDocuments ;
149+ this .topN = topN ;
136150 this .input = input ;
137151 this .taskSettings = taskSettings ;
138152 this .inputType = inputType ;
@@ -164,6 +178,14 @@ public Request(StreamInput in) throws IOException {
164178 this .inferenceTimeout = DEFAULT_TIMEOUT ;
165179 }
166180
181+ if (in .getTransportVersion ().onOrAfter (TransportVersions .RERANK_COMMON_OPTIONS_ADDED_8_19 )) {
182+ this .returnDocuments = in .readOptionalBoolean ();
183+ this .topN = in .readOptionalInt ();
184+ } else {
185+ this .returnDocuments = null ;
186+ this .topN = null ;
187+ }
188+
167189 // streaming is not supported yet for transport traffic
168190 this .stream = false ;
169191 }
@@ -184,6 +206,14 @@ public String getQuery() {
184206 return query ;
185207 }
186208
209+ public Boolean getReturnDocuments () {
210+ return returnDocuments ;
211+ }
212+
213+ public Integer getTopN () {
214+ return topN ;
215+ }
216+
187217 public Map <String , Object > getTaskSettings () {
188218 return taskSettings ;
189219 }
@@ -225,6 +255,17 @@ public ActionRequestValidationException validate() {
225255 e .addValidationError (format ("Field [query] cannot be empty for task type [%s]" , TaskType .RERANK ));
226256 return e ;
227257 }
258+ } else if (taskType .equals (TaskType .ANY ) == false ) {
259+ if (returnDocuments != null ) {
260+ var e = new ActionRequestValidationException ();
261+ e .addValidationError (format ("Field [return_documents] cannot be specified for task type [%s]" , taskType ));
262+ return e ;
263+ }
264+ if (topN != null ) {
265+ var e = new ActionRequestValidationException ();
266+ e .addValidationError (format ("Field [top_n] cannot be specified for task type [%s]" , taskType ));
267+ return e ;
268+ }
228269 }
229270
230271 if (taskType .equals (TaskType .TEXT_EMBEDDING ) == false
@@ -258,6 +299,11 @@ public void writeTo(StreamOutput out) throws IOException {
258299 out .writeOptionalString (query );
259300 out .writeTimeValue (inferenceTimeout );
260301 }
302+
303+ if (out .getTransportVersion ().onOrAfter (TransportVersions .RERANK_COMMON_OPTIONS_ADDED_8_19 )) {
304+ out .writeOptionalBoolean (returnDocuments );
305+ out .writeOptionalInt (topN );
306+ }
261307 }
262308
263309 // default for easier testing
@@ -283,6 +329,8 @@ public boolean equals(Object o) {
283329 && taskType == request .taskType
284330 && Objects .equals (inferenceEntityId , request .inferenceEntityId )
285331 && Objects .equals (query , request .query )
332+ && Objects .equals (returnDocuments , request .returnDocuments )
333+ && Objects .equals (topN , request .topN )
286334 && Objects .equals (input , request .input )
287335 && Objects .equals (taskSettings , request .taskSettings )
288336 && inputType == request .inputType
@@ -296,6 +344,8 @@ public int hashCode() {
296344 taskType ,
297345 inferenceEntityId ,
298346 query ,
347+ returnDocuments ,
348+ topN ,
299349 input ,
300350 taskSettings ,
301351 inputType ,
@@ -312,6 +362,8 @@ public static class Builder {
312362 private InputType inputType = InputType .UNSPECIFIED ;
313363 private Map <String , Object > taskSettings = Map .of ();
314364 private String query ;
365+ private Boolean returnDocuments ;
366+ private Integer topN ;
315367 private TimeValue timeout = DEFAULT_TIMEOUT ;
316368 private boolean stream = false ;
317369 private InferenceContext context ;
@@ -338,6 +390,16 @@ public Builder setQuery(String query) {
338390 return this ;
339391 }
340392
393+ public Builder setReturnDocuments (Boolean returnDocuments ) {
394+ this .returnDocuments = returnDocuments ;
395+ return this ;
396+ }
397+
398+ public Builder setTopN (Integer topN ) {
399+ this .topN = topN ;
400+ return this ;
401+ }
402+
341403 public Builder setInputType (InputType inputType ) {
342404 this .inputType = inputType ;
343405 return this ;
@@ -373,7 +435,19 @@ public Builder setContext(InferenceContext context) {
373435 }
374436
375437 public Request build () {
376- return new Request (taskType , inferenceEntityId , query , input , taskSettings , inputType , timeout , stream , context );
438+ return new Request (
439+ taskType ,
440+ inferenceEntityId ,
441+ query ,
442+ returnDocuments ,
443+ topN ,
444+ input ,
445+ taskSettings ,
446+ inputType ,
447+ timeout ,
448+ stream ,
449+ context
450+ );
377451 }
378452 }
379453
@@ -384,6 +458,10 @@ public String toString() {
384458 + this .getInferenceEntityId ()
385459 + ", query="
386460 + this .getQuery ()
461+ + ", returnDocuments="
462+ + this .getReturnDocuments ()
463+ + ", topN="
464+ + this .getTopN ()
387465 + ", input="
388466 + this .getInput ()
389467 + ", taskSettings="
0 commit comments