@@ -60,6 +60,8 @@ public static class Request extends BaseInferenceActionRequest {
6060 public static final ParseField INPUT_TYPE = new ParseField ("input_type" );
6161 public static final ParseField TASK_SETTINGS = new ParseField ("task_settings" );
6262 public static final ParseField QUERY = new ParseField ("query" );
63+ public static final ParseField RETURN_DOCUMENTS = new ParseField ("return_documents" );
64+ public static final ParseField TOP_N = new ParseField ("top_n" );
6365 public static final ParseField TIMEOUT = new ParseField ("timeout" );
6466
6567 static final ObjectParser <Request .Builder , Void > PARSER = new ObjectParser <>(NAME , Request .Builder ::new );
@@ -68,6 +70,8 @@ public static class Request extends BaseInferenceActionRequest {
6870 PARSER .declareString (Request .Builder ::setInputType , INPUT_TYPE );
6971 PARSER .declareObject (Request .Builder ::setTaskSettings , (p , c ) -> p .mapOrdered (), TASK_SETTINGS );
7072 PARSER .declareString (Request .Builder ::setQuery , QUERY );
73+ PARSER .declareBoolean (Request .Builder ::setReturnDocuments , RETURN_DOCUMENTS );
74+ PARSER .declareInt (Request .Builder ::setTopN , TOP_N );
7175 PARSER .declareString (Builder ::setInferenceTimeout , TIMEOUT );
7276 }
7377
@@ -89,6 +93,8 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
8993 private final TaskType taskType ;
9094 private final String inferenceEntityId ;
9195 private final String query ;
96+ private final Boolean returnDocuments ;
97+ private final Integer topN ;
9298 private final List <String > input ;
9399 private final Map <String , Object > taskSettings ;
94100 private final InputType inputType ;
@@ -99,6 +105,8 @@ public Request(
99105 TaskType taskType ,
100106 String inferenceEntityId ,
101107 String query ,
108+ Boolean returnDocuments ,
109+ Integer topN ,
102110 List <String > input ,
103111 Map <String , Object > taskSettings ,
104112 InputType inputType ,
@@ -109,6 +117,8 @@ public Request(
109117 taskType ,
110118 inferenceEntityId ,
111119 query ,
120+ returnDocuments ,
121+ topN ,
112122 input ,
113123 taskSettings ,
114124 inputType ,
@@ -122,6 +132,8 @@ public Request(
122132 TaskType taskType ,
123133 String inferenceEntityId ,
124134 String query ,
135+ Boolean returnDocuments ,
136+ Integer topN ,
125137 List <String > input ,
126138 Map <String , Object > taskSettings ,
127139 InputType inputType ,
@@ -133,6 +145,8 @@ public Request(
133145 this .taskType = taskType ;
134146 this .inferenceEntityId = inferenceEntityId ;
135147 this .query = query ;
148+ this .returnDocuments = returnDocuments ;
149+ this .topN = topN ;
136150 this .input = input ;
137151 this .taskSettings = taskSettings ;
138152 this .inputType = inputType ;
@@ -164,6 +178,15 @@ public Request(StreamInput in) throws IOException {
164178 this .inferenceTimeout = DEFAULT_TIMEOUT ;
165179 }
166180
181+ if (in .getTransportVersion ().onOrAfter (TransportVersions .RERANK_COMMON_OPTIONS_ADDED )
182+ || in .getTransportVersion ().isPatchFrom (TransportVersions .RERANK_COMMON_OPTIONS_ADDED_8_19 )) {
183+ this .returnDocuments = in .readOptionalBoolean ();
184+ this .topN = in .readOptionalInt ();
185+ } else {
186+ this .returnDocuments = null ;
187+ this .topN = null ;
188+ }
189+
167190 // streaming is not supported yet for transport traffic
168191 this .stream = false ;
169192 }
@@ -184,6 +207,14 @@ public String getQuery() {
184207 return query ;
185208 }
186209
210+ public Boolean getReturnDocuments () {
211+ return returnDocuments ;
212+ }
213+
214+ public Integer getTopN () {
215+ return topN ;
216+ }
217+
187218 public Map <String , Object > getTaskSettings () {
188219 return taskSettings ;
189220 }
@@ -225,6 +256,17 @@ public ActionRequestValidationException validate() {
225256 e .addValidationError (format ("Field [query] cannot be empty for task type [%s]" , TaskType .RERANK ));
226257 return e ;
227258 }
259+ } else if (taskType .equals (TaskType .ANY ) == false ) {
260+ if (returnDocuments != null ) {
261+ var e = new ActionRequestValidationException ();
262+ e .addValidationError (format ("Field [return_documents] cannot be specified for task type [%s]" , taskType ));
263+ return e ;
264+ }
265+ if (topN != null ) {
266+ var e = new ActionRequestValidationException ();
267+ e .addValidationError (format ("Field [top_n] cannot be specified for task type [%s]" , taskType ));
268+ return e ;
269+ }
228270 }
229271
230272 if (taskType .equals (TaskType .TEXT_EMBEDDING ) == false
@@ -258,6 +300,12 @@ public void writeTo(StreamOutput out) throws IOException {
258300 out .writeOptionalString (query );
259301 out .writeTimeValue (inferenceTimeout );
260302 }
303+
304+ if (out .getTransportVersion ().onOrAfter (TransportVersions .RERANK_COMMON_OPTIONS_ADDED )
305+ || out .getTransportVersion ().isPatchFrom (TransportVersions .RERANK_COMMON_OPTIONS_ADDED_8_19 )) {
306+ out .writeOptionalBoolean (returnDocuments );
307+ out .writeOptionalInt (topN );
308+ }
261309 }
262310
263311 // default for easier testing
@@ -283,6 +331,8 @@ public boolean equals(Object o) {
283331 && taskType == request .taskType
284332 && Objects .equals (inferenceEntityId , request .inferenceEntityId )
285333 && Objects .equals (query , request .query )
334+ && Objects .equals (returnDocuments , request .returnDocuments )
335+ && Objects .equals (topN , request .topN )
286336 && Objects .equals (input , request .input )
287337 && Objects .equals (taskSettings , request .taskSettings )
288338 && inputType == request .inputType
@@ -296,6 +346,8 @@ public int hashCode() {
296346 taskType ,
297347 inferenceEntityId ,
298348 query ,
349+ returnDocuments ,
350+ topN ,
299351 input ,
300352 taskSettings ,
301353 inputType ,
@@ -312,6 +364,8 @@ public static class Builder {
312364 private InputType inputType = InputType .UNSPECIFIED ;
313365 private Map <String , Object > taskSettings = Map .of ();
314366 private String query ;
367+ private Boolean returnDocuments ;
368+ private Integer topN ;
315369 private TimeValue timeout = DEFAULT_TIMEOUT ;
316370 private boolean stream = false ;
317371 private InferenceContext context ;
@@ -338,6 +392,16 @@ public Builder setQuery(String query) {
338392 return this ;
339393 }
340394
395+ public Builder setReturnDocuments (Boolean returnDocuments ) {
396+ this .returnDocuments = returnDocuments ;
397+ return this ;
398+ }
399+
400+ public Builder setTopN (Integer topN ) {
401+ this .topN = topN ;
402+ return this ;
403+ }
404+
341405 public Builder setInputType (InputType inputType ) {
342406 this .inputType = inputType ;
343407 return this ;
@@ -373,7 +437,19 @@ public Builder setContext(InferenceContext context) {
373437 }
374438
375439 public Request build () {
376- return new Request (taskType , inferenceEntityId , query , input , taskSettings , inputType , timeout , stream , context );
440+ return new Request (
441+ taskType ,
442+ inferenceEntityId ,
443+ query ,
444+ returnDocuments ,
445+ topN ,
446+ input ,
447+ taskSettings ,
448+ inputType ,
449+ timeout ,
450+ stream ,
451+ context
452+ );
377453 }
378454 }
379455
@@ -384,6 +460,10 @@ public String toString() {
384460 + this .getInferenceEntityId ()
385461 + ", query="
386462 + this .getQuery ()
463+ + ", returnDocuments="
464+ + this .getReturnDocuments ()
465+ + ", topN="
466+ + this .getTopN ()
387467 + ", input="
388468 + this .getInput ()
389469 + ", taskSettings="
0 commit comments