Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,37 @@ Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
* @return an optional containing the vector scorer, or empty
*/
Optional<RandomVectorScorer> getInt7SQVectorScorer(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector);

/**
* Returns an optional containing an int7 optimal scalar quantized (OSQ) vector scorer
* supplier for the given parameters, or an empty optional if a scorer is not supported.
*
* <p>Note that {@code values} is the {@code lucene104} flavour of
* {@code QuantizedByteVectorValues} and is spelled fully qualified here; presumably this
* is because a different {@code QuantizedByteVectorValues} is already imported by simple
* name for the int7 SQ methods of this interface (confirm against the import block).
*
* @param similarityType the similarity type
* @param input the index input containing the vector data
* @param values the random access vector values
* @return an optional containing the vector scorer supplier, or empty
*/
Optional<RandomVectorScorerSupplier> getInt7uOSQVectorScorerSupplier(
VectorSimilarityType similarityType,
IndexInput input,
org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values
);

/**
* Returns an optional containing an int7 optimal scalar quantized (OSQ) vector scorer for
* the given parameters, or an empty optional if a scorer is not supported.
*
* <p>NOTE(review): the descriptions of the query-side parameters below are inferred from
* their names and types only; confirm them against the quantizer that produces these values.
*
* @param sim the similarity function
* @param values the random access vector values
* @param quantizedQuery the query vector in quantized byte form (presumably one byte per
*                       dimension; confirm)
* @param lowerInterval presumably the lower bound of the query's quantization interval (confirm)
* @param upperInterval presumably the upper bound of the query's quantization interval (confirm)
* @param additionalCorrection presumably an extra similarity-dependent correction term for
*                             the query (confirm)
* @param quantizedComponentSum presumably the sum of the quantized query components (confirm)
* @return an optional containing the vector scorer, or empty
*/
Optional<RandomVectorScorer> getInt7uOSQVectorScorer(
VectorSimilarityFunction sim,
org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values,
byte[] quantizedQuery,
float lowerInterval,
float upperInterval,
float additionalCorrection,
int quantizedComponentSum
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,26 @@ public Optional<RandomVectorScorer> getInt7SQVectorScorer(
) {
throw new UnsupportedOperationException("should not reach here");
}

/**
* Int7 OSQ scorer suppliers are not provided by this factory implementation; the method
* unconditionally fails instead of returning an empty optional.
*
* @param similarityType ignored
* @param input ignored
* @param values ignored
* @return never returns normally
* @throws UnsupportedOperationException always
*/
@Override
public Optional<RandomVectorScorerSupplier> getInt7uOSQVectorScorerSupplier(
VectorSimilarityType similarityType,
IndexInput input,
org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values
) {
// Reaching this method is treated as a programming error rather than "not supported".
throw new UnsupportedOperationException("should not reach here");
}

/**
* Int7 OSQ query-time scorers are not provided by this factory implementation; the method
* unconditionally fails instead of returning an empty optional. All arguments are ignored.
*
* @return never returns normally
* @throws UnsupportedOperationException always
*/
@Override
public Optional<RandomVectorScorer> getInt7uOSQVectorScorer(
VectorSimilarityFunction sim,
org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values,
byte[] quantizedQuery,
float lowerInterval,
float upperInterval,
float additionalCorrection,
int quantizedComponentSum
) {
// Reaching this method is treated as a programming error rather than "not supported".
throw new UnsupportedOperationException("should not reach here");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.elasticsearch.simdvec.internal.FloatVectorScorerSupplier;
import org.elasticsearch.simdvec.internal.Int7SQVectorScorer;
import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier;
import org.elasticsearch.simdvec.internal.Int7uOSQVectorScorer;
import org.elasticsearch.simdvec.internal.Int7uOSQVectorScorerSupplier;

import java.util.Optional;

Expand Down Expand Up @@ -90,6 +92,45 @@ public Optional<RandomVectorScorer> getInt7SQVectorScorer(
return Int7SQVectorScorer.create(sim, values, queryVector);
}

/**
 * Builds an int7 OSQ scorer supplier that reads vectors through a memory-segment backed input.
 *
 * <p>The input is first passed through {@code FilterIndexInput.unwrapOnlyTest}. If the result is
 * not a {@code MemorySegmentAccessInput}, no supplier is available and an empty optional is
 * returned. Otherwise the input is validated with {@code checkInvariants} and a supplier matching
 * the similarity type is created; {@code COSINE} and {@code DOT_PRODUCT} share the dot-product
 * supplier.
 *
 * @param similarityType the similarity type selecting the supplier implementation
 * @param input the index input containing the vector data
 * @param values the random access vector values
 * @return the supplier, or empty when the input is not memory-segment accessible
 */
@Override
public Optional<RandomVectorScorerSupplier> getInt7uOSQVectorScorerSupplier(
    VectorSimilarityType similarityType,
    IndexInput input,
    org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values
) {
    final IndexInput unwrapped = FilterIndexInput.unwrapOnlyTest(input);
    if ((unwrapped instanceof MemorySegmentAccessInput) == false) {
        return Optional.empty();
    }
    final MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) unwrapped;

    // Validate before constructing any supplier, exactly as the nested form did.
    checkInvariants(values.size(), values.dimension(), unwrapped);

    final RandomVectorScorerSupplier supplier = switch (similarityType) {
        case COSINE, DOT_PRODUCT -> new Int7uOSQVectorScorerSupplier.DotProductSupplier(msInput, values);
        case EUCLIDEAN -> new Int7uOSQVectorScorerSupplier.EuclideanSupplier(msInput, values);
        case MAXIMUM_INNER_PRODUCT -> new Int7uOSQVectorScorerSupplier.MaxInnerProductSupplier(msInput, values);
    };
    return Optional.of(supplier);
}

/**
* Creates an int7 OSQ query-time scorer by delegating to {@code Int7uOSQVectorScorer.create},
* forwarding every argument unchanged and returning its result as-is.
*
* @param sim the similarity function
* @param values the random access vector values
* @param quantizedQuery the query vector in quantized byte form
* @param lowerInterval forwarded unchanged to the scorer factory
* @param upperInterval forwarded unchanged to the scorer factory
* @param additionalCorrection forwarded unchanged to the scorer factory
* @param quantizedComponentSum forwarded unchanged to the scorer factory
* @return whatever {@code Int7uOSQVectorScorer.create} returns for these arguments
*/
@Override
public Optional<RandomVectorScorer> getInt7uOSQVectorScorer(
VectorSimilarityFunction sim,
org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues values,
byte[] quantizedQuery,
float lowerInterval,
float upperInterval,
float additionalCorrection,
int quantizedComponentSum
) {
return Int7uOSQVectorScorer.create(
sim,
values,
quantizedQuery,
lowerInterval,
upperInterval,
additionalCorrection,
quantizedComponentSum
);
}

static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
if (input.length() < (long) vectorByteLength * maxOrd) {
throw new IllegalArgumentException("input length is less than expected vector data");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.simdvec.internal;

import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

import java.util.Optional;

/**
 * Placeholder entry point for the Int7 OSQ query-time scorer. The eventual implementation is
 * intended to dispatch to the native OSQ routines and apply the similarity-specific
 * corrections; for now {@link #create} reports that no scorer is available.
 */
public final class Int7uOSQVectorScorer {

    /** Static factory holder; never instantiated. */
    private Int7uOSQVectorScorer() {}

    /**
     * Attempts to create a query-time scorer for the given quantized query.
     *
     * <p>Currently no argument is inspected and the result is always empty.
     *
     * @param sim the similarity function
     * @param values the quantized vector values to score against
     * @param quantizedQuery the query vector in quantized byte form
     * @param lowerInterval currently unused
     * @param upperInterval currently unused
     * @param additionalCorrection currently unused
     * @param quantizedComponentSum currently unused
     * @return an empty optional
     */
    public static Optional<RandomVectorScorer> create(
        VectorSimilarityFunction sim,
        QuantizedByteVectorValues values,
        byte[] quantizedQuery,
        float lowerInterval,
        float upperInterval,
        float additionalCorrection,
        int quantizedComponentSum
    ) {
        // TODO add JDK21 fallback logic and native scorer dispatch
        return Optional.empty();
    }
}
Loading