4646import org .elasticsearch .xpack .esql .expression .function .fulltext .Match ;
4747import org .elasticsearch .xpack .esql .expression .function .fulltext .MatchOperator ;
4848import org .elasticsearch .xpack .esql .expression .function .fulltext .QueryString ;
49+ import org .elasticsearch .xpack .esql .expression .function .scalar .string .Concat ;
4950import org .elasticsearch .xpack .esql .expression .function .scalar .string .Substring ;
5051import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .Equals ;
5152import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .GreaterThan ;
7172import org .elasticsearch .xpack .esql .plan .logical .Row ;
7273import org .elasticsearch .xpack .esql .plan .logical .RrfScoreEval ;
7374import org .elasticsearch .xpack .esql .plan .logical .UnresolvedRelation ;
75+ import org .elasticsearch .xpack .esql .plan .logical .inference .Completion ;
7476import org .elasticsearch .xpack .esql .plan .logical .inference .Rerank ;
7577import org .elasticsearch .xpack .esql .plan .logical .local .EsqlProject ;
7678import org .elasticsearch .xpack .esql .plugin .EsqlPlugin ;
9294import static org .elasticsearch .xpack .esql .EsqlTestUtils .as ;
9395import static org .elasticsearch .xpack .esql .EsqlTestUtils .configuration ;
9496import static org .elasticsearch .xpack .esql .EsqlTestUtils .emptyInferenceResolution ;
97+ import static org .elasticsearch .xpack .esql .EsqlTestUtils .getAttributeByName ;
9598import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsConstant ;
9699import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsIdentifier ;
97100import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsPattern ;
101+ import static org .elasticsearch .xpack .esql .EsqlTestUtils .referenceAttribute ;
98102import static org .elasticsearch .xpack .esql .EsqlTestUtils .withDefaultLimitWarning ;
99103import static org .elasticsearch .xpack .esql .analysis .Analyzer .NO_FIELDS ;
100104import static org .elasticsearch .xpack .esql .analysis .AnalyzerTestUtils .analyze ;
@@ -3460,7 +3464,7 @@ public void testResolveRerankInferenceId() {
34603464
34613465 {
34623466 LogicalPlan plan = analyze (
3463- " FROM books METADATA _score | RERANK \" italian food recipe\" ON title WITH `reranking-inference-id`" ,
3467+ "FROM books METADATA _score | RERANK \" italian food recipe\" ON title WITH `reranking-inference-id`" ,
34643468 "mapping-books.json"
34653469 );
34663470 Rerank rerank = as (as (plan , Limit .class ).child (), Rerank .class );
@@ -3530,16 +3534,13 @@ public void testResolveRerankFields() {
35303534 Filter filter = as (drop .child (), Filter .class );
35313535 EsRelation relation = as (filter .child (), EsRelation .class );
35323536
3533- Attribute titleAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "title" )). findFirst (). get ( );
3534- assertThat (titleAttribute , notNullValue ());
3537+ Attribute titleAttribute = getAttributeByName ( relation .output (), "title" );
3538+ assertThat (getAttributeByName ( relation . output (), "title" ) , notNullValue ());
35353539
35363540 assertThat (rerank .queryText (), equalTo (string ("italian food recipe" )));
35373541 assertThat (rerank .inferenceId (), equalTo (string ("reranking-inference-id" )));
35383542 assertThat (rerank .rerankFields (), equalTo (List .of (alias ("title" , titleAttribute ))));
3539- assertThat (
3540- rerank .scoreAttribute (),
3541- equalTo (relation .output ().stream ().filter (attr -> attr .name ().equals (MetadataAttribute .SCORE )).findFirst ().get ())
3542- );
3543+ assertThat (rerank .scoreAttribute (), equalTo (getAttributeByName (relation .output (), MetadataAttribute .SCORE )));
35433544 }
35443545
35453546 {
@@ -3559,15 +3560,11 @@ public void testResolveRerankFields() {
35593560 assertThat (rerank .inferenceId (), equalTo (string ("reranking-inference-id" )));
35603561
35613562 assertThat (rerank .rerankFields (), hasSize (3 ));
3562- Attribute titleAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "title" )). findFirst (). get ( );
3563+ Attribute titleAttribute = getAttributeByName ( relation .output (), "title" );
35633564 assertThat (titleAttribute , notNullValue ());
35643565 assertThat (rerank .rerankFields ().get (0 ), equalTo (alias ("title" , titleAttribute )));
35653566
3566- Attribute descriptionAttribute = relation .output ()
3567- .stream ()
3568- .filter (attribute -> attribute .name ().equals ("description" ))
3569- .findFirst ()
3570- .get ();
3567+ Attribute descriptionAttribute = getAttributeByName (relation .output (), "description" );
35713568 assertThat (descriptionAttribute , notNullValue ());
35723569 Alias descriptionAlias = rerank .rerankFields ().get (1 );
35733570 assertThat (descriptionAlias .name (), equalTo ("description" ));
@@ -3576,13 +3573,11 @@ public void testResolveRerankFields() {
35763573 equalTo (List .of (descriptionAttribute , literal (0 ), literal (100 )))
35773574 );
35783575
3579- Attribute yearAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "year" )). findFirst (). get ( );
3576+ Attribute yearAttribute = getAttributeByName ( relation .output (), "year" );
35803577 assertThat (yearAttribute , notNullValue ());
35813578 assertThat (rerank .rerankFields ().get (2 ), equalTo (alias ("yearRenamed" , yearAttribute )));
3582- assertThat (
3583- rerank .scoreAttribute (),
3584- equalTo (relation .output ().stream ().filter (attr -> attr .name ().equals (MetadataAttribute .SCORE )).findFirst ().get ())
3585- );
3579+
3580+ assertThat (rerank .scoreAttribute (), equalTo (getAttributeByName (relation .output (), MetadataAttribute .SCORE )));
35863581 }
35873582
35883583 {
@@ -3614,11 +3609,7 @@ public void testResolveRerankScoreField() {
36143609 Filter filter = as (rerank .child (), Filter .class );
36153610 EsRelation relation = as (filter .child (), EsRelation .class );
36163611
3617- Attribute metadataScoreAttribute = relation .output ()
3618- .stream ()
3619- .filter (attr -> attr .name ().equals (MetadataAttribute .SCORE ))
3620- .findFirst ()
3621- .get ();
3612+ Attribute metadataScoreAttribute = getAttributeByName (relation .output (), MetadataAttribute .SCORE );
36223613 assertThat (rerank .scoreAttribute (), equalTo (metadataScoreAttribute ));
36233614 assertThat (rerank .output (), hasItem (metadataScoreAttribute ));
36243615 }
@@ -3642,6 +3633,116 @@ public void testResolveRerankScoreField() {
36423633 }
36433634 }
36443635
3636+ public void testResolveCompletionInferenceId () {
3637+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3638+
3639+ LogicalPlan plan = analyze ("""
3640+ FROM books METADATA _score
3641+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3642+ """ , "mapping-books.json" );
3643+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3644+ assertThat (completion .inferenceId (), equalTo (string ("completion-inference-id" )));
3645+ }
3646+
3647+ public void testResolveCompletionInferenceIdInvalidTaskType () {
3648+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3649+
3650+ assertError (
3651+ """
3652+ FROM books METADATA _score
3653+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `reranking-inference-id`
3654+ """ ,
3655+ "mapping-books.json" ,
3656+ new QueryParams (),
3657+ "cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command."
3658+ + " Only inference endpoints with the task type [completion] are supported"
3659+ );
3660+ }
3661+
3662+ public void testResolveCompletionInferenceMissingInferenceId () {
3663+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3664+
3665+ assertError ("""
3666+ FROM books METADATA _score
3667+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `unknown-inference-id`
3668+ """ , "mapping-books.json" , new QueryParams (), "unresolved inference [unknown-inference-id]" );
3669+ }
3670+
3671+ public void testResolveCompletionInferenceIdResolutionError () {
3672+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3673+
3674+ assertError ("""
3675+ FROM books METADATA _score
3676+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `error-inference-id`
3677+ """ , "mapping-books.json" , new QueryParams (), "error with inference resolution" );
3678+ }
3679+
3680+ public void testResolveCompletionTargetField () {
3681+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3682+
3683+ LogicalPlan plan = analyze ("""
3684+ FROM books METADATA _score
3685+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id` AS translation
3686+ """ , "mapping-books.json" );
3687+
3688+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3689+ assertThat (completion .targetField (), equalTo (referenceAttribute ("translation" , DataType .TEXT )));
3690+ }
3691+
3692+ public void testResolveCompletionDefaultTargetField () {
3693+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3694+
3695+ LogicalPlan plan = analyze ("""
3696+ FROM books METADATA _score
3697+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3698+ """ , "mapping-books.json" );
3699+
3700+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3701+ assertThat (completion .targetField (), equalTo (referenceAttribute ("completion" , DataType .TEXT )));
3702+ }
3703+
3704+ public void testResolveCompletionPrompt () {
3705+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3706+
3707+ LogicalPlan plan = analyze ("""
3708+ FROM books METADATA _score
3709+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3710+ """ , "mapping-books.json" );
3711+
3712+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3713+ EsRelation esRelation = as (completion .child (), EsRelation .class );
3714+
3715+ assertThat (
3716+ as (completion .prompt (), Concat .class ).children (),
3717+ equalTo (List .of (string ("Translate the following text in French\n " ), getAttributeByName (esRelation .output (), "description" )))
3718+ );
3719+ }
3720+
3721+ public void testResolveCompletionPromptInvalidType () {
3722+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3723+
3724+ assertError ("""
3725+ FROM books METADATA _score
3726+ | COMPLETION LENGTH(description) WITH `completion-inference-id`
3727+ """ , "mapping-books.json" , new QueryParams (), "prompt must be of type [text] but is [integer]" );
3728+ }
3729+
3730+ public void testResolveCompletionOutputField () {
3731+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3732+
3733+ LogicalPlan plan = analyze ("""
3734+ FROM books METADATA _score
3735+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id` AS description
3736+ """ , "mapping-books.json" );
3737+
3738+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3739+ assertThat (completion .targetField (), equalTo (referenceAttribute ("description" , DataType .TEXT )));
3740+
3741+ EsRelation esRelation = as (completion .child (), EsRelation .class );
3742+ assertThat (getAttributeByName (completion .output (), "description" ), equalTo (completion .targetField ()));
3743+ assertThat (getAttributeByName (esRelation .output (), "description" ), not (equalTo (completion .targetField ())));
3744+ }
3745+
36453746 @ Override
36463747 protected IndexAnalyzers createDefaultIndexAnalyzers () {
36473748 return super .createDefaultIndexAnalyzers ();
0 commit comments