|
13 | 13 | import org.apache.logging.log4j.Logger; |
14 | 14 | import org.apache.lucene.index.LeafReaderContext; |
15 | 15 | import org.apache.lucene.search.TotalHits; |
| 16 | +import org.apache.lucene.util.automaton.CharacterRunAutomaton; |
16 | 17 | import org.elasticsearch.common.bytes.BytesReference; |
17 | 18 | import org.elasticsearch.common.regex.Regex; |
18 | 19 | import org.elasticsearch.index.fieldvisitor.LeafStoredFieldLoader; |
19 | 20 | import org.elasticsearch.index.fieldvisitor.StoredFieldLoader; |
20 | 21 | import org.elasticsearch.index.mapper.IdLoader; |
| 22 | +import org.elasticsearch.index.mapper.MappedFieldType; |
21 | 23 | import org.elasticsearch.index.mapper.SourceLoader; |
22 | | -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; |
23 | | -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; |
24 | 24 | import org.elasticsearch.search.LeafNestedDocuments; |
25 | 25 | import org.elasticsearch.search.NestedDocuments; |
26 | 26 | import org.elasticsearch.search.SearchContextSourcePrinter; |
@@ -122,7 +122,7 @@ private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Pr |
122 | 122 | // - Speed up retrieval of the synthetic source |
123 | 123 | // Note: These vectors will no longer be accessible via _source for any sub-fetch processors, |
124 | 124 | // but they are typically accessed through doc values instead (e.g. by the re-scorer). |
125 | | - SourceFilter sourceFilter = maybeExcludeNonSemanticTextVectors(context); |
| 125 | + SourceFilter sourceFilter = maybeExcludeNonSemanticTextVectorFields(context); |
126 | 126 | SourceLoader sourceLoader = context.newSourceLoader(sourceFilter); |
127 | 127 | FetchContext fetchContext = new FetchContext(context, sourceLoader); |
128 | 128 |
|
@@ -461,24 +461,53 @@ private static boolean shouldExcludeVectorsFromSource(SearchContext context) { |
461 | 461 | * unless vectors are explicitly requested to be included in the source. |
462 | 462 | * Returns {@code null} when vectors should not be filtered out. |
463 | 463 | */ |
464 | | - private static SourceFilter maybeExcludeNonSemanticTextVectors(SearchContext context) { |
| 464 | + private static SourceFilter maybeExcludeNonSemanticTextVectorFields(SearchContext context) { |
465 | 465 | if (shouldExcludeVectorsFromSource(context) == false) { |
466 | 466 | return null; |
467 | 467 | } |
468 | 468 | var lookup = context.getSearchExecutionContext().getMappingLookup(); |
469 | | - List<String> inferencePatterns = lookup.inferenceFields().isEmpty() |
470 | | - ? null |
471 | | - : lookup.inferenceFields().keySet().stream().map(f -> f + "*").toList(); |
472 | | - var excludes = lookup.getFullNameToFieldType() |
473 | | - .values() |
474 | | - .stream() |
475 | | - .filter( |
476 | | - f -> f instanceof DenseVectorFieldMapper.DenseVectorFieldType || f instanceof SparseVectorFieldMapper.SparseVectorFieldType |
| 469 | + var fetchFieldsAut = context.fetchFieldsContext() != null && context.fetchFieldsContext().fields().size() > 0 |
| 470 | + ? new CharacterRunAutomaton( |
| 471 | + Regex.simpleMatchToAutomaton(context.fetchFieldsContext().fields().stream().map(f -> f.field).toArray(String[]::new)) |
477 | 472 | ) |
| 473 | + : null; |
| 474 | + var inferenceFieldsAut = lookup.inferenceFields().size() > 0 |
| 475 | + ? new CharacterRunAutomaton( |
| 476 | + Regex.simpleMatchToAutomaton(lookup.inferenceFields().keySet().stream().map(f -> f + "*").toArray(String[]::new)) |
| 477 | + ) |
| 478 | + : null; |
| 479 | + |
| 480 | + List<String> lateExcludes = new ArrayList<>(); |
| 481 | + var excludes = lookup.getFullNameToFieldType().values().stream().filter(MappedFieldType::isVectorEmbedding).filter(f -> { |
| 482 | + // Vector fields matching the `fields` option are kept out of the source-loader excludes; as a side effect of this filter they are recorded in lateExcludes for later exclusion |
| 483 | + if (fetchFieldsAut != null && fetchFieldsAut.run(f.name())) { |
| 484 | + lateExcludes.add(f.name()); |
| 485 | + return false; |
| 486 | + } |
478 | 487 | // Keep vectors that belong to semantic text fields out of the excludes list, as they are processed separately |
479 | | - .filter(f -> Regex.simpleMatch(inferencePatterns, f.name()) == false) |
480 | | - .map(f -> f.name()) |
481 | | - .collect(Collectors.toList()); |
| 488 | + return inferenceFieldsAut == null || inferenceFieldsAut.run(f.name()) == false; |
| 489 | + }).map(f -> f.name()).collect(Collectors.toList()); |
| 490 | + |
| 491 | + if (lateExcludes.size() > 0) { |
| 492 | + /** |
| 493 | + * Adds the vector fields matched by the `fields` option to the excludes of the fetch source context; any excludes already present on that context are appended and its other settings are carried over. |
| 494 | + * This ensures that these vector fields are available to sub-fetch phases, but excluded during the {@link FetchSourcePhase}. |
| 495 | + */ |
| 496 | + if (context.fetchSourceContext() != null && context.fetchSourceContext().excludes() != null) { |
| 497 | + for (var exclude : context.fetchSourceContext().excludes()) { |
| 498 | + lateExcludes.add(exclude); |
| 499 | + } |
| 500 | + } |
| 501 | + var fetchSourceContext = context.fetchSourceContext() == null |
| 502 | + ? FetchSourceContext.of(true, false, null, lateExcludes.toArray(String[]::new)) |
| 503 | + : FetchSourceContext.of( |
| 504 | + context.fetchSourceContext().fetchSource(), |
| 505 | + context.fetchSourceContext().includeVectors(), |
| 506 | + context.fetchSourceContext().includes(), |
| 507 | + lateExcludes.toArray(String[]::new) |
| 508 | + ); |
| 509 | + context.fetchSourceContext(fetchSourceContext); |
| 510 | + } |
482 | 511 | return excludes.isEmpty() ? null : new SourceFilter(new String[] {}, excludes.toArray(String[]::new)); |
483 | 512 | } |
484 | 513 | } |
0 commit comments