Skip to content

Commit 3c21268

Browse files
committed
adjust sample reader to use a LCG mapping
1 parent fe9e86e commit 3c21268

File tree

2 files changed

+75
-18
lines changed

2 files changed

+75
-18
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.lucene.index.FloatVectorValues;
2424
import org.apache.lucene.store.IndexInput;
2525
import org.apache.lucene.util.Bits;
26+
import org.apache.lucene.util.MathUtil;
2627

2728
import java.io.IOException;
2829
import java.util.Random;
@@ -81,30 +82,52 @@ public Bits getAcceptOrds(Bits acceptDocs) {
8182

8283
static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
8384
// TODO can we do something algorithmically that aligns an ordinal with a unique integer between 0 and numVectors?
84-
int[] samples = reservoirSample(origin.size(), k, seed);
85-
return new SampleReader(origin, samples.length, i -> samples[i]);
85+
if (k >= origin.size()) {
86+
new SampleReader(origin, origin.size(), i -> i);
87+
}
88+
Random rnd = new Random(seed);
89+
RandomLinearCongruentialMapper mapper = new RandomLinearCongruentialMapper(k, origin.size(), rnd);
90+
return new SampleReader(origin, k, i -> (int) mapper.map(i));
8691
}
8792

8893
/**
89-
* Sample k elements from n elements according to reservoir sampling algorithm.
90-
*
91-
* @param n number of elements
92-
* @param k number of samples
93-
* @param seed random seed
94-
* @return array of k samples
94+
* RandomLinearCongruentialMapper is used to map a range of integers [1, n] to a range of integers [1, m]
9595
*/
96-
public static int[] reservoirSample(int n, int k, long seed) {
97-
Random rnd = new Random(seed);
98-
int[] reservoir = new int[k];
99-
for (int i = 0; i < k; i++) {
100-
reservoir[i] = i;
96+
static class RandomLinearCongruentialMapper {
97+
private final long n;
98+
private final long m;
99+
private final long multiplier;
100+
private final int randomLinearShift;
101+
102+
RandomLinearCongruentialMapper(long smaller, long larger, Random random) {
103+
assert smaller > 0 && larger > 0;
104+
assert smaller < larger;
105+
this.n = smaller;
106+
this.m = larger;
107+
this.multiplier = findLargeOddCoprime(n);
108+
this.randomLinearShift = random.nextInt(0, 1024 * 1024);
109+
}
110+
111+
// need to ensure positive modulus only
112+
private static long mod(long x, long m) {
113+
long r = x % m;
114+
return r < 0 ? r + m : r;
101115
}
102-
for (int i = k; i < n; i++) {
103-
int j = rnd.nextInt(i + 1);
104-
if (j < k) {
105-
reservoir[j] = i;
116+
117+
long map(long i) {
118+
if (i < 0 || i >= n) {
119+
throw new IllegalArgumentException("i out of range");
120+
}
121+
long permuted = mod((i * multiplier + randomLinearShift), n);
122+
return 1 + mod(permuted, m);
123+
}
124+
125+
private static long findLargeOddCoprime(long n) {
126+
long candidate = n | 1; // make sure it's odd
127+
while (MathUtil.gcd(candidate, n) != 1) {
128+
candidate += 2;
106129
}
130+
return candidate;
107131
}
108-
return reservoir;
109132
}
110133
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.util.FixedBitSet;
13+
import org.elasticsearch.test.ESTestCase;
14+
15+
public class SampleReaderTests extends ESTestCase {
16+
17+
public void testRandomSampling() {
18+
int randomLongLower = randomIntBetween(0, 1024 * 10);
19+
int randomLongUpper = randomIntBetween(randomLongLower, 1024 * 100);
20+
SampleReader.RandomLinearCongruentialMapper mapper = new SampleReader.RandomLinearCongruentialMapper(
21+
randomLongLower,
22+
randomLongUpper,
23+
random()
24+
);
25+
FixedBitSet valueSeen = new FixedBitSet(randomLongUpper + 1);
26+
for (int i = 0; i < randomLongLower; i++) {
27+
long mapped = mapper.map(i);
28+
assertTrue(mapped >= 0);
29+
assertTrue(mapped <= randomLongUpper);
30+
assertFalse(valueSeen.getAndSet((int) mapped));
31+
}
32+
}
33+
34+
}

0 commit comments

Comments
 (0)