@@ -60,6 +60,8 @@ public static class Request extends BaseInferenceActionRequest {
6060        public  static  final  ParseField  INPUT_TYPE  = new  ParseField ("input_type" );
6161        public  static  final  ParseField  TASK_SETTINGS  = new  ParseField ("task_settings" );
6262        public  static  final  ParseField  QUERY  = new  ParseField ("query" );
63+         public  static  final  ParseField  RETURN_DOCUMENTS  = new  ParseField ("return_documents" );
64+         public  static  final  ParseField  TOP_N  = new  ParseField ("top_n" );
6365        public  static  final  ParseField  TIMEOUT  = new  ParseField ("timeout" );
6466
6567        static  final  ObjectParser <Request .Builder , Void > PARSER  = new  ObjectParser <>(NAME , Request .Builder ::new );
@@ -68,6 +70,8 @@ public static class Request extends BaseInferenceActionRequest {
6870            PARSER .declareString (Request .Builder ::setInputType , INPUT_TYPE );
6971            PARSER .declareObject (Request .Builder ::setTaskSettings , (p , c ) -> p .mapOrdered (), TASK_SETTINGS );
7072            PARSER .declareString (Request .Builder ::setQuery , QUERY );
73+             PARSER .declareBoolean (Request .Builder ::setReturnDocuments , RETURN_DOCUMENTS );
74+             PARSER .declareInt (Request .Builder ::setTopN , TOP_N );
7175            PARSER .declareString (Builder ::setInferenceTimeout , TIMEOUT );
7276        }
7377
@@ -89,6 +93,8 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
8993        private  final  TaskType  taskType ;
9094        private  final  String  inferenceEntityId ;
9195        private  final  String  query ;
96+         private  final  Boolean  returnDocuments ;
97+         private  final  Integer  topN ;
9298        private  final  List <String > input ;
9399        private  final  Map <String , Object > taskSettings ;
94100        private  final  InputType  inputType ;
@@ -99,6 +105,8 @@ public Request(
99105            TaskType  taskType ,
100106            String  inferenceEntityId ,
101107            String  query ,
108+             Boolean  returnDocuments ,
109+             Integer  topN ,
102110            List <String > input ,
103111            Map <String , Object > taskSettings ,
104112            InputType  inputType ,
@@ -109,6 +117,8 @@ public Request(
109117                taskType ,
110118                inferenceEntityId ,
111119                query ,
120+                 returnDocuments ,
121+                 topN ,
112122                input ,
113123                taskSettings ,
114124                inputType ,
@@ -122,6 +132,8 @@ public Request(
122132            TaskType  taskType ,
123133            String  inferenceEntityId ,
124134            String  query ,
135+             Boolean  returnDocuments ,
136+             Integer  topN ,
125137            List <String > input ,
126138            Map <String , Object > taskSettings ,
127139            InputType  inputType ,
@@ -133,6 +145,8 @@ public Request(
133145            this .taskType  = taskType ;
134146            this .inferenceEntityId  = inferenceEntityId ;
135147            this .query  = query ;
148+             this .returnDocuments  = returnDocuments ;
149+             this .topN  = topN ;
136150            this .input  = input ;
137151            this .taskSettings  = taskSettings ;
138152            this .inputType  = inputType ;
@@ -164,6 +178,15 @@ public Request(StreamInput in) throws IOException {
164178                this .inferenceTimeout  = DEFAULT_TIMEOUT ;
165179            }
166180
181+             if  (in .getTransportVersion ().onOrAfter (TransportVersions .RERANK_COMMON_OPTIONS_ADDED )
182+                 || in .getTransportVersion ().isPatchFrom (TransportVersions .RERANK_COMMON_OPTIONS_ADDED_8_19 )) {
183+                 this .returnDocuments  = in .readOptionalBoolean ();
184+                 this .topN  = in .readOptionalInt ();
185+             } else  {
186+                 this .returnDocuments  = null ;
187+                 this .topN  = null ;
188+             }
189+ 
167190            // streaming is not supported yet for transport traffic 
168191            this .stream  = false ;
169192        }
@@ -184,6 +207,14 @@ public String getQuery() {
184207            return  query ;
185208        }
186209
210+         public  Boolean  getReturnDocuments () {
211+             return  returnDocuments ;
212+         }
213+ 
214+         public  Integer  getTopN () {
215+             return  topN ;
216+         }
217+ 
187218        public  Map <String , Object > getTaskSettings () {
188219            return  taskSettings ;
189220        }
@@ -225,6 +256,17 @@ public ActionRequestValidationException validate() {
225256                    e .addValidationError (format ("Field [query] cannot be empty for task type [%s]" , TaskType .RERANK ));
226257                    return  e ;
227258                }
259+             } else  if  (taskType .equals (TaskType .ANY ) == false ) {
260+                 if  (returnDocuments  != null ) {
261+                     var  e  = new  ActionRequestValidationException ();
262+                     e .addValidationError (format ("Field [return_documents] cannot be specified for task type [%s]" , taskType ));
263+                     return  e ;
264+                 }
265+                 if  (topN  != null ) {
266+                     var  e  = new  ActionRequestValidationException ();
267+                     e .addValidationError (format ("Field [top_n] cannot be specified for task type [%s]" , taskType ));
268+                     return  e ;
269+                 }
228270            }
229271
230272            if  (taskType .equals (TaskType .TEXT_EMBEDDING ) == false 
@@ -258,6 +300,12 @@ public void writeTo(StreamOutput out) throws IOException {
258300                out .writeOptionalString (query );
259301                out .writeTimeValue (inferenceTimeout );
260302            }
303+ 
304+             if  (out .getTransportVersion ().onOrAfter (TransportVersions .RERANK_COMMON_OPTIONS_ADDED )
305+                 || out .getTransportVersion ().isPatchFrom (TransportVersions .RERANK_COMMON_OPTIONS_ADDED_8_19 )) {
306+                 out .writeOptionalBoolean (returnDocuments );
307+                 out .writeOptionalInt (topN );
308+             }
261309        }
262310
263311        // default for easier testing 
@@ -283,6 +331,8 @@ public boolean equals(Object o) {
283331                && taskType  == request .taskType 
284332                && Objects .equals (inferenceEntityId , request .inferenceEntityId )
285333                && Objects .equals (query , request .query )
334+                 && Objects .equals (returnDocuments , request .returnDocuments )
335+                 && Objects .equals (topN , request .topN )
286336                && Objects .equals (input , request .input )
287337                && Objects .equals (taskSettings , request .taskSettings )
288338                && inputType  == request .inputType 
@@ -296,6 +346,8 @@ public int hashCode() {
296346                taskType ,
297347                inferenceEntityId ,
298348                query ,
349+                 returnDocuments ,
350+                 topN ,
299351                input ,
300352                taskSettings ,
301353                inputType ,
@@ -312,6 +364,8 @@ public static class Builder {
312364            private  InputType  inputType  = InputType .UNSPECIFIED ;
313365            private  Map <String , Object > taskSettings  = Map .of ();
314366            private  String  query ;
367+             private  Boolean  returnDocuments ;
368+             private  Integer  topN ;
315369            private  TimeValue  timeout  = DEFAULT_TIMEOUT ;
316370            private  boolean  stream  = false ;
317371            private  InferenceContext  context ;
@@ -338,6 +392,16 @@ public Builder setQuery(String query) {
338392                return  this ;
339393            }
340394
395+             public  Builder  setReturnDocuments (Boolean  returnDocuments ) {
396+                 this .returnDocuments  = returnDocuments ;
397+                 return  this ;
398+             }
399+ 
400+             public  Builder  setTopN (Integer  topN ) {
401+                 this .topN  = topN ;
402+                 return  this ;
403+             }
404+ 
341405            public  Builder  setInputType (InputType  inputType ) {
342406                this .inputType  = inputType ;
343407                return  this ;
@@ -373,7 +437,19 @@ public Builder setContext(InferenceContext context) {
373437            }
374438
375439            public  Request  build () {
376-                 return  new  Request (taskType , inferenceEntityId , query , input , taskSettings , inputType , timeout , stream , context );
440+                 return  new  Request (
441+                     taskType ,
442+                     inferenceEntityId ,
443+                     query ,
444+                     returnDocuments ,
445+                     topN ,
446+                     input ,
447+                     taskSettings ,
448+                     inputType ,
449+                     timeout ,
450+                     stream ,
451+                     context 
452+                 );
377453            }
378454        }
379455
@@ -384,6 +460,10 @@ public String toString() {
384460                + this .getInferenceEntityId ()
385461                + ", query=" 
386462                + this .getQuery ()
463+                 + ", returnDocuments=" 
464+                 + this .getReturnDocuments ()
465+                 + ", topN=" 
466+                 + this .getTopN ()
387467                + ", input=" 
388468                + this .getInput ()
389469                + ", taskSettings=" 
0 commit comments