1010import  org .elasticsearch .common .io .stream .NamedWriteableRegistry ;
1111import  org .elasticsearch .common .io .stream .StreamInput ;
1212import  org .elasticsearch .common .io .stream .StreamOutput ;
13+ import  org .elasticsearch .common .lucene .BytesRefs ;
14+ import  org .elasticsearch .inference .ModelConfigurations ;
1315import  org .elasticsearch .inference .TaskType ;
1416import  org .elasticsearch .xpack .esql .capabilities .TelemetryAware ;
17+ import  org .elasticsearch .xpack .esql .core .capabilities .Unresolvable ;
1518import  org .elasticsearch .xpack .esql .core .expression .Attribute ;
1619import  org .elasticsearch .xpack .esql .core .expression .Expression ;
20+ import  org .elasticsearch .xpack .esql .core .expression .FoldContext ;
21+ import  org .elasticsearch .xpack .esql .core .expression .Literal ;
1722import  org .elasticsearch .xpack .esql .core .expression .NameId ;
1823import  org .elasticsearch .xpack .esql .core .tree .NodeInfo ;
1924import  org .elasticsearch .xpack .esql .core .tree .Source ;
25+ import  org .elasticsearch .xpack .esql .core .type .DataType ;
2026import  org .elasticsearch .xpack .esql .io .stream .PlanStreamInput ;
2127import  org .elasticsearch .xpack .esql .plan .logical .LogicalPlan ;
2228import  org .elasticsearch .xpack .esql .plan .logical .inference .InferencePlan ;
@@ -36,13 +42,26 @@ public class DenseVectorEmbedding extends InferencePlan<DenseVectorEmbedding> im
3642    );
3743
3844    private  final  Expression  input ;
45+     private  final  Expression  dimensions ;
3946    private  final  Attribute  targetField ;
4047    private  List <Attribute > lazyOutput ;
4148
4249    public  DenseVectorEmbedding (Source  source , LogicalPlan  child , Expression  inferenceId , Expression  input , Attribute  targetField ) {
50+         this (source , child , inferenceId , new  UnresolvedDimensions (inferenceId ), input , targetField );
51+     }
52+ 
53+     DenseVectorEmbedding (
54+         Source  source ,
55+         LogicalPlan  child ,
56+         Expression  inferenceId ,
57+         Expression  dimensions ,
58+         Expression  input ,
59+         Attribute  targetField 
60+     ) {
4361        super (source , child , inferenceId );
4462        this .input  = input ;
4563        this .targetField  = targetField ;
64+         this .dimensions  = dimensions ;
4665    }
4766
4867    public  DenseVectorEmbedding (StreamInput  in ) throws  IOException  {
@@ -51,6 +70,7 @@ public DenseVectorEmbedding(StreamInput in) throws IOException {
5170            in .readNamedWriteable (LogicalPlan .class ),
5271            in .readNamedWriteable (Expression .class ),
5372            in .readNamedWriteable (Expression .class ),
73+             in .readNamedWriteable (Expression .class ),
5474            in .readNamedWriteable (Attribute .class )
5575        );
5676    }
@@ -60,6 +80,7 @@ public void writeTo(StreamOutput out) throws IOException {
6080        source ().writeTo (out );
6181        out .writeNamedWriteable (child ());
6282        out .writeNamedWriteable (inferenceId ());
83+         out .writeNamedWriteable (dimensions );
6384        out .writeNamedWriteable (input );
6485        out .writeNamedWriteable (targetField );
6586    }
@@ -77,6 +98,10 @@ public TaskType taskType() {
7798        return  TaskType .TEXT_EMBEDDING ;
7899    }
79100
101+     public  Expression  dimensions () {
102+         return  dimensions ;
103+     }
104+ 
80105    @ Override 
81106    public  String  getWriteableName () {
82107        return  ENTRY .name ;
@@ -98,7 +123,7 @@ public List<Attribute> generatedAttributes() {
98123    @ Override 
99124    public  DenseVectorEmbedding  withGeneratedNames (List <String > newNames ) {
100125        checkNumberOfNewNames (newNames );
101-         return  new  DenseVectorEmbedding (source (), child (), inferenceId (), input , this .renameTargetField (newNames .get (0 )));
126+         return  new  DenseVectorEmbedding (source (), child (), inferenceId (), dimensions ,  input , this .renameTargetField (newNames .get (0 )));
102127    }
103128
104129    private  Attribute  renameTargetField (String  newName ) {
@@ -111,22 +136,45 @@ private Attribute renameTargetField(String newName) {
111136
112137    @ Override 
113138    public  boolean  expressionsResolved () {
114-         return  super .expressionsResolved () && input .resolved () && targetField .resolved ();
139+         return  super .expressionsResolved () && input .resolved () && targetField .resolved () &&  dimensions . resolved () ;
115140    }
116141
117142    @ Override 
118143    public  DenseVectorEmbedding  withInferenceId (Expression  newInferenceId ) {
119-         return  new  DenseVectorEmbedding (source (), child (), newInferenceId , input , targetField );
144+         return  new  DenseVectorEmbedding (source (), child (), newInferenceId , dimensions , input , targetField );
145+     }
146+ 
147+     public  DenseVectorEmbedding  withDimensions (Expression  newDimensions ) {
148+         return  new  DenseVectorEmbedding (source (), child (), inferenceId (), newDimensions , input , targetField );
149+     }
150+ 
151+     public  DenseVectorEmbedding  withTargetField (Attribute  targetField ) {
152+         return  new  DenseVectorEmbedding (source (), child (), inferenceId (), dimensions , input , targetField );
153+     }
154+ 
155+     @ Override 
156+     public  DenseVectorEmbedding  withModelConfigurations (ModelConfigurations  modelConfig ) {
157+         boolean  hasChanged  = false ;
158+         Expression  newDimensions  = dimensions ;
159+ 
160+         if  (dimensions .resolved () == false 
161+             && modelConfig .getServiceSettings () != null 
162+             && modelConfig .getServiceSettings ().dimensions () > 0 ) {
163+             hasChanged  = true ;
164+             newDimensions  = new  Literal (Source .EMPTY , modelConfig .getServiceSettings ().dimensions (), DataType .INTEGER );
165+         }
166+ 
167+         return  hasChanged  ? withDimensions (newDimensions ) : this ;
120168    }
121169
122170    @ Override 
123171    public  DenseVectorEmbedding  replaceChild (LogicalPlan  newChild ) {
124-         return  new  DenseVectorEmbedding (source (), newChild , inferenceId (), input , targetField );
172+         return  new  DenseVectorEmbedding (source (), newChild , inferenceId (), dimensions ,  input , targetField );
125173    }
126174
127175    @ Override 
128176    protected  NodeInfo <? extends  LogicalPlan > info () {
129-         return  NodeInfo .create (this , DenseVectorEmbedding ::new , child (), inferenceId (), input , targetField );
177+         return  NodeInfo .create (this , DenseVectorEmbedding ::new , child (), inferenceId (), dimensions ,  input , targetField );
130178    }
131179
132180    @ Override 
@@ -135,11 +183,28 @@ public boolean equals(Object o) {
135183        if  (o  == null  || getClass () != o .getClass ()) return  false ;
136184        if  (super .equals (o ) == false ) return  false ;
137185        DenseVectorEmbedding  that  = (DenseVectorEmbedding ) o ;
138-         return  Objects .equals (input , that .input ) && Objects .equals (targetField , that .targetField );
186+         return  Objects .equals (input , that .input )
187+             && Objects .equals (dimensions , that .dimensions )
188+             && Objects .equals (targetField , that .targetField );
139189    }
140190
141191    @ Override 
142192    public  int  hashCode () {
143-         return  Objects .hash (super .hashCode (), input , targetField );
193+         return  Objects .hash (super .hashCode (), input , targetField , dimensions );
194+     }
195+ 
196+     private  static  class  UnresolvedDimensions  extends  Literal  implements  Unresolvable  {
197+ 
198+         private  final  String  inferenceId ;
199+ 
200+         private  UnresolvedDimensions (Expression  inferenceId ) {
201+             super (Source .EMPTY , null , DataType .NULL );
202+             this .inferenceId  = BytesRefs .toString (inferenceId .fold (FoldContext .small ()));
203+         }
204+ 
205+         @ Override 
206+         public  String  unresolvedMessage () {
207+             return  "Dimensions cannot be resolved for inference endpoint["  + inferenceId  + "]" ;
208+         }
144209    }
145210}
0 commit comments