Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,28 @@

package org.elasticsearch.compute.operator.lookup;

import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FilterDirectoryReader;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOBooleanSupplier;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.IntBlock;
Expand All @@ -29,6 +42,10 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;

/**
* Lookup document IDs for the input queries.
Expand Down Expand Up @@ -60,11 +77,227 @@ public EnrichQuerySourceOperator(
this.queryList = queryList;
this.shardContext = shardContext;
this.shardContext.incRef();
this.searcher = shardContext.searcher();
this.indexReader = searcher.getIndexReader();
try {
this.indexReader = new CachedDirectoryReader((DirectoryReader) shardContext.searcher().getIndexReader());
this.searcher = new IndexSearcher(this.indexReader);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
this.warnings = warnings;
}

static class CachedDirectoryReader extends FilterDirectoryReader {
CachedDirectoryReader(DirectoryReader in) throws IOException {
super(in, new FilterDirectoryReader.SubReaderWrapper() {
@Override
public LeafReader wrap(LeafReader reader) {
return new CachedLeafReader(reader);
}
});
}

@Override
protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException {
return new CachedDirectoryReader(in);
}

@Override
public CacheHelper getReaderCacheHelper() {
return in.getReaderCacheHelper();
}
}

static class CachedLeafReader extends FilterLeafReader {
final Map<String, NumericDocValues> docValues = new HashMap<>();
final Map<String, TermsEnum> termEnums = new HashMap<>();

CachedLeafReader(LeafReader in) {
super(in);
}

@Override
public NumericDocValues getNumericDocValues(String field) throws IOException {
NumericDocValues dv = super.getNumericDocValues(field);
if (dv == null) {
return null;
}
return new CachedNumericDocValues(docId -> docValues.compute(field, (k, curr) -> {
if (curr == null || curr.docID() > docId) {
return dv;
}
return curr;
}));
}

@Override
public Terms terms(String field) throws IOException {
Terms terms = super.terms(field);
if (terms == null) {
return null;
}
return new FilterTerms(terms) {
@Override
public TermsEnum iterator() throws IOException {
return new CachedTermsEnum((reuse) -> {
return termEnums.compute(field, (k, curr) -> {
if (curr == null || reuse == false) {
try {
curr = in.iterator();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
return curr;
});
});
}
};
}

@Override
public CacheHelper getCoreCacheHelper() {
return in.getCoreCacheHelper();
}

@Override
public CacheHelper getReaderCacheHelper() {
return in.getCoreCacheHelper();
}
}

static class CachedNumericDocValues extends NumericDocValues {
private NumericDocValues delegate = null;
private final IntFunction<NumericDocValues> fromCache;

CachedNumericDocValues(IntFunction<NumericDocValues> fromCache) {
this.fromCache = fromCache;
}

NumericDocValues getDelegate(int docID) {
if (delegate == null) {
delegate = fromCache.apply(docID);
}
return delegate;
}

@Override
public long longValue() throws IOException {
return getDelegate(-1).longValue();
}

@Override
public boolean advanceExact(int target) throws IOException {
return getDelegate(target).advanceExact(target);
}

@Override
public int advance(int target) throws IOException {
return getDelegate(target).nextDoc();
}

@Override
public int docID() {
return getDelegate(-1).docID();
}

@Override
public int nextDoc() throws IOException {
return getDelegate(-1).nextDoc();
}

@Override
public long cost() {
return fromCache.apply(DocIdSetIterator.NO_MORE_DOCS).cost();
}
}

static class CachedTermsEnum extends TermsEnum {
private TermsEnum delegate = null;
private final Function<Boolean, TermsEnum> fromCache;

CachedTermsEnum(Function<Boolean, TermsEnum> fromCache) {
this.fromCache = fromCache;
}

TermsEnum getDelegate(boolean reuse) {
if (delegate == null) {
delegate = fromCache.apply(reuse);
}
return delegate;
}

@Override
public AttributeSource attributes() {
return getDelegate(false).attributes();
}

@Override
public boolean seekExact(BytesRef text) throws IOException {
return getDelegate(true).seekExact(text);
}

@Override
public IOBooleanSupplier prepareSeekExact(BytesRef text) throws IOException {
return getDelegate(true).prepareSeekExact(text);
}

@Override
public void seekExact(long ord) throws IOException {
getDelegate(true).seekExact(ord);
}

@Override
public void seekExact(BytesRef term, TermState state) throws IOException {
// TODO: when this can be true?
getDelegate(false).seekExact(term, state);
}

@Override
public SeekStatus seekCeil(BytesRef text) throws IOException {
return getDelegate(false).seekCeil(text);
}

@Override
public BytesRef term() throws IOException {
return getDelegate(false).term();
}

@Override
public long ord() throws IOException {
return getDelegate(false).ord();
}

@Override
public int docFreq() throws IOException {
return getDelegate(false).docFreq();
}

@Override
public long totalTermFreq() throws IOException {
return getDelegate(false).totalTermFreq();
}

@Override
public PostingsEnum postings(PostingsEnum reuse, int flags) throws IOException {
return getDelegate(false).postings(reuse, flags);
}

@Override
public ImpactsEnum impacts(int flags) throws IOException {
return getDelegate(false).impacts(flags);
}

@Override
public TermState termState() throws IOException {
return getDelegate(false).termState();
}

@Override
public BytesRef next() throws IOException {
return getDelegate(false).next();
}
}

@Override
public void finish() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@

package org.elasticsearch.xpack.esql.enrich;

import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.search.Query;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.operator.Warnings;
import org.elasticsearch.compute.operator.lookup.QueryList;
import org.elasticsearch.index.fielddata.IndexNumericFieldData;
import org.elasticsearch.index.mapper.DateFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
Expand Down Expand Up @@ -55,13 +60,7 @@ public BinaryComparisonQueryList(
AliasFilter aliasFilter,
Warnings warnings
) {
super(
field,
searchExecutionContext,
aliasFilter,
leftHandSideBlock,
new OnlySingleValueParams(warnings, "LOOKUP JOIN encountered multi-value")
);
super(field, searchExecutionContext, aliasFilter, leftHandSideBlock, null);
// swap left and right if the field is on the right
// We get a filter in the form left_expr >= right_expr
// here we will swap it to right_expr <= left_expr
Expand Down Expand Up @@ -93,6 +92,12 @@ public Query doGetQuery(int position, int firstValueIndex, int valueCount) {
);
try {
if (TranslationAware.Translatable.YES.equals(comparison.translatable(lucenePushdownPredicates))) {
// Check if this is a numeric/doc values field comparison (but not NEQ)
if (isNumericOrDateField(field)
&& comparison.left() instanceof FieldAttribute
&& comparison.getFunctionType() != EsqlBinaryComparison.BinaryComparisonOperation.NEQ) {
return createNumericDocValuesQuery(comparison, value);
}
return comparison.asQuery(lucenePushdownPredicates, TranslatorHandler.TRANSLATOR_HANDLER)
.toQueryBuilder()
.toQuery(searchExecutionContext);
Expand All @@ -103,4 +108,40 @@ public Query doGetQuery(int position, int firstValueIndex, int valueCount) {
throw new UncheckedIOException("Error while building query for join on filter:", e);
}
}

private boolean isNumericOrDateField(MappedFieldType field) {
if (field instanceof DateFieldMapper.DateFieldType) {
return true;
}
if (field instanceof NumberFieldMapper.NumberFieldType numberFieldType) {
// Exclude floating-point types to avoid precision loss in createNumericDocValuesQuery
return numberFieldType.numericType() != IndexNumericFieldData.NumericType.DOUBLE
&& numberFieldType.numericType() != IndexNumericFieldData.NumericType.FLOAT
&& numberFieldType.numericType() != IndexNumericFieldData.NumericType.HALF_FLOAT;
}
return false;
}

private Query createNumericDocValuesQuery(EsqlBinaryComparison comparison, Object value) {
String fieldName = field.name();
Number numericValue = (Number) value;
// Convert the value to long for NumericDocValuesField (works for both numeric and date fields)
long longValue = numericValue.longValue();

// Create range query based on comparison type
switch (comparison.getFunctionType()) {
case GT:
return NumericDocValuesField.newSlowRangeQuery(fieldName, longValue + 1, Long.MAX_VALUE);
case GTE:
return NumericDocValuesField.newSlowRangeQuery(fieldName, longValue, Long.MAX_VALUE);
case LT:
return NumericDocValuesField.newSlowRangeQuery(fieldName, Long.MIN_VALUE, longValue - 1);
case LTE:
return NumericDocValuesField.newSlowRangeQuery(fieldName, Long.MIN_VALUE, longValue);
case EQ:
return NumericDocValuesField.newSlowRangeQuery(fieldName, longValue, longValue);
default:
throw new IllegalArgumentException("Unsupported comparison type: " + comparison.getFunctionType());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ private void buildJoinOnForExpressionJoin(
aliasFilter,
inputPage.getBlock(matchFields.get(i).channel()),
matchFields.get(i).type()
).onlySingleValues(warnings, "LOOKUP JOIN encountered multi-value");
);
queryLists.add(termQueryForEquals);
} else {
queryLists.add(
Expand Down