|
23 | 23 | import org.apache.lucene.index.FloatVectorValues; |
24 | 24 | import org.apache.lucene.store.IndexInput; |
25 | 25 | import org.apache.lucene.util.Bits; |
| 26 | +import org.apache.lucene.util.MathUtil; |
26 | 27 |
|
27 | 28 | import java.io.IOException; |
28 | 29 | import java.util.Random; |
@@ -81,30 +82,52 @@ public Bits getAcceptOrds(Bits acceptDocs) { |
81 | 82 |
|
82 | 83 | static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) { |
83 | 84 | // TODO can we do something algorithmically that aligns an ordinal with a unique integer between 0 and numVectors? |
84 | | - int[] samples = reservoirSample(origin.size(), k, seed); |
85 | | - return new SampleReader(origin, samples.length, i -> samples[i]); |
| 85 | + if (k >= origin.size()) { |
| 86 | + new SampleReader(origin, origin.size(), i -> i); |
| 87 | + } |
| 88 | + Random rnd = new Random(seed); |
| 89 | + RandomLinearCongruentialMapper mapper = new RandomLinearCongruentialMapper(k, origin.size(), rnd); |
| 90 | + return new SampleReader(origin, k, i -> (int) mapper.map(i)); |
86 | 91 | } |
87 | 92 |
|
88 | 93 | /** |
89 | | - * Sample k elements from n elements according to reservoir sampling algorithm. |
90 | | - * |
91 | | - * @param n number of elements |
92 | | - * @param k number of samples |
93 | | - * @param seed random seed |
94 | | - * @return array of k samples |
| 94 | + * RandomLinearCongruentialMapper is used to map a range of integers [1, n] to a range of integers [1, m] |
95 | 95 | */ |
96 | | - public static int[] reservoirSample(int n, int k, long seed) { |
97 | | - Random rnd = new Random(seed); |
98 | | - int[] reservoir = new int[k]; |
99 | | - for (int i = 0; i < k; i++) { |
100 | | - reservoir[i] = i; |
| 96 | + static class RandomLinearCongruentialMapper { |
| 97 | + private final long n; |
| 98 | + private final long m; |
| 99 | + private final long multiplier; |
| 100 | + private final int randomLinearShift; |
| 101 | + |
| 102 | + RandomLinearCongruentialMapper(long smaller, long larger, Random random) { |
| 103 | + assert smaller > 0 && larger > 0; |
| 104 | + assert smaller < larger; |
| 105 | + this.n = smaller; |
| 106 | + this.m = larger; |
| 107 | + this.multiplier = findLargeOddCoprime(n); |
| 108 | + this.randomLinearShift = random.nextInt(0, 1024 * 1024); |
| 109 | + } |
| 110 | + |
| 111 | + // need to ensure positive modulus only |
| 112 | + private static long mod(long x, long m) { |
| 113 | + long r = x % m; |
| 114 | + return r < 0 ? r + m : r; |
101 | 115 | } |
102 | | - for (int i = k; i < n; i++) { |
103 | | - int j = rnd.nextInt(i + 1); |
104 | | - if (j < k) { |
105 | | - reservoir[j] = i; |
| 116 | + |
| 117 | + long map(long i) { |
| 118 | + if (i < 0 || i >= n) { |
| 119 | + throw new IllegalArgumentException("i out of range"); |
| 120 | + } |
| 121 | + long permuted = mod((i * multiplier + randomLinearShift), n); |
| 122 | + return 1 + mod(permuted, m); |
| 123 | + } |
| 124 | + |
| 125 | + private static long findLargeOddCoprime(long n) { |
| 126 | + long candidate = n | 1; // make sure it's odd |
| 127 | + while (MathUtil.gcd(candidate, n) != 1) { |
| 128 | + candidate += 2; |
106 | 129 | } |
| 130 | + return candidate; |
107 | 131 | } |
108 | | - return reservoir; |
109 | 132 | } |
110 | 133 | } |
0 commit comments