|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.esql.expression.function.vector; |
9 | 9 |
|
| 10 | +import org.apache.logging.log4j.LogManager; |
| 11 | +import org.apache.logging.log4j.Logger; |
10 | 12 | import org.elasticsearch.common.io.stream.NamedWriteableRegistry; |
11 | 13 | import org.elasticsearch.common.io.stream.StreamInput; |
12 | 14 | import org.elasticsearch.common.io.stream.StreamOutput; |
13 | 15 | import org.elasticsearch.index.query.QueryBuilder; |
| 16 | +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; |
14 | 17 | import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; |
15 | 18 | import org.elasticsearch.xpack.esql.common.Failures; |
16 | 19 | import org.elasticsearch.xpack.esql.core.InvalidArgumentException; |
17 | 20 | import org.elasticsearch.xpack.esql.core.expression.Expression; |
18 | 21 | import org.elasticsearch.xpack.esql.core.expression.FoldContext; |
| 22 | +import org.elasticsearch.xpack.esql.core.expression.Literal; |
19 | 23 | import org.elasticsearch.xpack.esql.core.expression.MapExpression; |
20 | 24 | import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; |
21 | 25 | import org.elasticsearch.xpack.esql.core.querydsl.query.Query; |
|
46 | 50 | import java.util.function.BiConsumer; |
47 | 51 |
|
48 | 52 | import static java.util.Map.entry; |
| 53 | +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; |
49 | 54 | import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; |
50 | 55 | import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; |
51 | 56 | import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; |
|
56 | 61 | import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; |
57 | 62 | import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; |
58 | 63 | import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; |
59 | | -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; |
60 | 64 | import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; |
61 | 65 | import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; |
62 | 66 | import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; |
63 | 67 | import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; |
| 68 | +import static org.elasticsearch.xpack.esql.expression.function.FunctionUtils.resolveTypeQuery; |
64 | 69 |
|
65 | 70 | public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware { |
| 71 | + private final Logger log = LogManager.getLogger(getClass()); |
66 | 72 |
|
67 | 73 | public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom); |
68 | 74 |
|
@@ -189,9 +195,16 @@ private TypeResolution resolveField() { |
189 | 195 | } |
190 | 196 |
|
191 | 197 | private TypeResolution resolveQuery() { |
192 | | - return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and( |
193 | | - isNotNullAndFoldable(query(), sourceText(), SECOND) |
194 | | - ); |
| 198 | + TypeResolution result = isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector") |
| 199 | + .and(isNotNull(query(), sourceText(), SECOND)); |
| 200 | + if (result.unresolved()) { |
| 201 | + return result; |
| 202 | + } |
| 203 | + result = resolveTypeQuery(query(), sourceText()); |
| 204 | + if (result.equals(TypeResolution.TYPE_RESOLVED) == false) { |
| 205 | + return result; |
| 206 | + } |
| 207 | + return TypeResolution.TYPE_RESOLVED; |
195 | 208 | } |
196 | 209 |
|
197 | 210 | private TypeResolution resolveK() { |
@@ -235,19 +248,59 @@ private Map<String, Object> knnQueryOptions() throws InvalidArgumentException { |
235 | 248 | return matchOptions; |
236 | 249 | } |
237 | 250 |
|
| 251 | + @Override |
| 252 | + public boolean partiallyFoldable() { |
| 253 | + return true; |
| 254 | + } |
| 255 | + |
| 256 | + @Override |
| 257 | + public Expression partiallyFold(FoldContext ctx) { |
| 258 | + if (k instanceof Literal) { |
| 259 | + // already folded, return self |
| 260 | + return this; |
| 261 | + } |
| 262 | + Object foldedK = k.fold(ctx); |
| 263 | + if (foldedK instanceof Number == false) { |
| 264 | + throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k())); |
| 265 | + } |
| 266 | + List<Expression> newChildren = new ArrayList<>(this.children()); |
| 267 | + newChildren.set(2, new Literal(source(), foldedK, INTEGER)); |
| 268 | + Expression replaced = this.replaceChildren(newChildren); |
| 269 | + log.error("Partially folded knn function [{}] with k value [{}]", replaced, foldedK); |
| 270 | + return replaced; |
| 271 | + } |
| 272 | + |
| 273 | + @Override |
| 274 | + public List<Number> queryAsObject() { |
| 275 | + // we need to check that we got a list and every element in the list is a number |
| 276 | + Expression query = query(); |
| 277 | + if (query instanceof Literal literal) { |
| 278 | + @SuppressWarnings("unchecked") |
| 279 | + List<Number> result = ((List<Number>) literal.value()); |
| 280 | + return result; |
| 281 | + } |
| 282 | + throw new EsqlIllegalArgumentException(format(null, "Query value must be a list of numbers in [{}], found [{}]", source(), query)); |
| 283 | + } |
| 284 | + |
| 285 | + int getKIntValue() { |
| 286 | + if (k() instanceof Literal literal) { |
| 287 | + return (int) (Number) literal.value(); |
| 288 | + } |
| 289 | + throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k())); |
| 290 | + } |
| 291 | + |
238 | 292 | @Override |
239 | 293 | protected Query translate(TranslatorHandler handler) { |
240 | 294 | var fieldAttribute = Match.fieldAsFieldAttribute(field()); |
241 | 295 |
|
242 | 296 | Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument"); |
243 | 297 | String fieldName = getNameFromFieldAttribute(fieldAttribute); |
244 | | - @SuppressWarnings("unchecked") |
245 | | - List<Number> queryFolded = (List<Number>) query().fold(FoldContext.small() /* TODO remove me */); |
| 298 | + List<Number> queryFolded = queryAsObject(); |
246 | 299 | float[] queryAsFloats = new float[queryFolded.size()]; |
247 | 300 | for (int i = 0; i < queryFolded.size(); i++) { |
248 | 301 | queryAsFloats[i] = queryFolded.get(i).floatValue(); |
249 | 302 | } |
250 | | - int kValue = ((Number) k().fold(FoldContext.small())).intValue(); |
| 303 | + int kValue = getKIntValue(); |
251 | 304 |
|
252 | 305 | Map<String, Object> opts = queryOptions(); |
253 | 306 | opts.put(K_FIELD.getPreferredName(), kValue); |
@@ -322,12 +375,13 @@ public boolean equals(Object o) { |
322 | 375 | Knn knn = (Knn) o; |
323 | 376 | return Objects.equals(field(), knn.field()) |
324 | 377 | && Objects.equals(query(), knn.query()) |
325 | | - && Objects.equals(queryBuilder(), knn.queryBuilder()); |
| 378 | + && Objects.equals(queryBuilder(), knn.queryBuilder()) |
| 379 | + && Objects.equals(k(), knn.k()); |
326 | 380 | } |
327 | 381 |
|
328 | 382 | @Override |
329 | 383 | public int hashCode() { |
330 | | - return Objects.hash(field(), query(), queryBuilder()); |
| 384 | + return Objects.hash(field(), query(), queryBuilder(), k()); |
331 | 385 | } |
332 | 386 |
|
333 | 387 | } |
0 commit comments