Skip to content

Commit 2c67ff9

Browse files
committed
feat internal: weighted random sampler util
Enables easy tweaking of probabilities for indidual mutation functions in the future.
1 parent 3e86f6e commit 2c67ff9

File tree

3 files changed

+213
-0
lines changed

3 files changed

+213
-0
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright 2025 Code Intelligence GmbH
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.code_intelligence.jazzer.mutation.combinator;
18+
19+
import static com.code_intelligence.jazzer.mutation.support.Preconditions.require;
20+
import static java.util.Objects.requireNonNull;
21+
22+
import com.code_intelligence.jazzer.mutation.api.PseudoRandom;
23+
import java.util.Arrays;
24+
import java.util.List;
25+
import java.util.Optional;
26+
import java.util.function.Function;
27+
import java.util.stream.Collectors;
28+
29+
public final class SamplingUtils {
30+
31+
public static <T> Function<PseudoRandom, T> weightedSampler(T[] values, double[] weights) {
32+
// Use Vose's alias method for O(1) sampling after O(n) preprocessing.
33+
requireNonNull(values, "Values must not be null");
34+
requireNonNull(weights, "Weights must not be null");
35+
require(values.length > 0, "Values must not be empty");
36+
require(values.length == weights.length, "Values and weights must have the same length");
37+
38+
double sum = Arrays.stream(weights).sum();
39+
require(sum > 0, "At least one weight must be positive");
40+
41+
int n = values.length;
42+
int[] alias = new int[n];
43+
double[] probability = new double[n];
44+
double[] scaledWeights = Arrays.stream(weights).map(w -> w * n / sum).toArray();
45+
int[] small = new int[n];
46+
int[] large = new int[n];
47+
int smallCount = 0;
48+
int largeCount = 0;
49+
for (int i = 0; i < n; i++) {
50+
if (scaledWeights[i] < 1.0) {
51+
small[smallCount++] = i;
52+
} else {
53+
large[largeCount++] = i;
54+
}
55+
}
56+
57+
while (smallCount > 0 && largeCount > 0) {
58+
int less = small[--smallCount];
59+
int more = large[--largeCount];
60+
61+
probability[less] = scaledWeights[less];
62+
alias[less] = more;
63+
scaledWeights[more] = (scaledWeights[more] + scaledWeights[less]) - 1.0;
64+
65+
if (scaledWeights[more] < 1.0) {
66+
small[smallCount++] = more;
67+
} else {
68+
large[largeCount++] = more;
69+
}
70+
}
71+
while (largeCount > 0) {
72+
probability[large[--largeCount]] = 1.0;
73+
}
74+
75+
while (smallCount > 0) {
76+
probability[small[--smallCount]] = 1.0;
77+
}
78+
return (PseudoRandom random) -> {
79+
int column = random.indexIn(n);
80+
return values[random.closedRange(0.0, 1.0) < probability[column] ? column : alias[column]];
81+
};
82+
}
83+
84+
public static <T> Function<PseudoRandom, T> weightedSampler(
85+
List<WeightedMutationFunction<T>> weightedFunctions) {
86+
requireNonNull(weightedFunctions, "Weighted functions must not be null");
87+
require(!weightedFunctions.isEmpty(), "Weighted functions must not be empty");
88+
89+
double[] weights = weightedFunctions.stream().mapToDouble(m -> m.weight).toArray();
90+
91+
T[] fns = (T[]) weightedFunctions.stream().map(m -> m.fn).toArray(Object[]::new);
92+
93+
return weightedSampler(fns, weights);
94+
}
95+
96+
@SafeVarargs
97+
public static <T> Function<PseudoRandom, T> weightedSampler(
98+
Optional<WeightedMutationFunction<T>>... values) {
99+
return weightedSampler(
100+
Arrays.stream(values)
101+
.filter(Optional::isPresent)
102+
.map(Optional::get)
103+
.collect(Collectors.toList()));
104+
}
105+
106+
/**
107+
* A simple struct to hold a mutation function and its weight. It is here just for stylistic
108+
* reasons, to make the definitions of weights and functions more readable.
109+
*/
110+
public static class WeightedMutationFunction<T> {
111+
public final double weight;
112+
public final T fn;
113+
114+
public WeightedMutationFunction(double weight, T fn) {
115+
this.fn = fn;
116+
this.weight = weight;
117+
}
118+
119+
public static <T> WeightedMutationFunction<T> of(double weight, T fn) {
120+
return new WeightedMutationFunction<>(weight, fn);
121+
}
122+
123+
public static <T> Optional<WeightedMutationFunction<T>> ofOptional(double weight, T fn) {
124+
return Optional.of(new WeightedMutationFunction<>(weight, fn));
125+
}
126+
}
127+
}

src/test/java/com/code_intelligence/jazzer/mutation/combinator/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ java_test_suite(
88
deps = [
99
"//src/main/java/com/code_intelligence/jazzer/mutation/api",
1010
"//src/main/java/com/code_intelligence/jazzer/mutation/combinator",
11+
"//src/main/java/com/code_intelligence/jazzer/mutation/engine",
1112
"//src/main/java/com/code_intelligence/jazzer/mutation/support",
1213
"//src/test/java/com/code_intelligence/jazzer/mutation/support:test_support",
1314
],
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2025 Code Intelligence GmbH
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.code_intelligence.jazzer.mutation.combinator;
18+
19+
import static org.junit.jupiter.params.provider.Arguments.arguments;
20+
21+
import com.code_intelligence.jazzer.mutation.api.PseudoRandom;
22+
import com.code_intelligence.jazzer.mutation.engine.SeededPseudoRandom;
23+
import java.util.function.Function;
24+
import java.util.stream.IntStream;
25+
import java.util.stream.Stream;
26+
import org.junit.jupiter.params.ParameterizedTest;
27+
import org.junit.jupiter.params.provider.MethodSource;
28+
29+
public class SamplingUtilsTest {
30+
static Stream<?> weightsProvider() {
31+
final int N = 1000000;
32+
final double T = 0.03;
33+
return Stream.of(
34+
arguments(N, T, new double[] {1.0, 1.0, 1.0}),
35+
arguments(N, T, new double[] {1.0, 2.0, 3.0, 4.0, 5.0}),
36+
arguments(N, T, new double[] {0.1, 0.2, 0.3, 0.4}),
37+
arguments(N, T, new double[] {10.0, 0.0, 0.1, 0.0, 90.0}),
38+
arguments(N, T, new double[] {5.0, 5.0, 0.0, 0.0, 0.01, 5.0, 5.0}),
39+
arguments(N, T, new double[] {0.0, 0.0, 0.0, 1.0}),
40+
arguments(N, T, new double[] {1.0}),
41+
arguments(N, T, new double[] {0.01, 0.01, 0.01, 0.97}),
42+
arguments(N, T, new double[] {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}),
43+
arguments(N, T, new double[] {0.001, 0.002, 0.003, 0.004, 0.005}),
44+
arguments(N, T, new double[] {0.001, 0.002, 0.003, 0.004, 0.000001, 10.0}),
45+
arguments(N, T, new double[] {0.001, 1000.0, 0.003, 10000.0, 0.005}),
46+
arguments(N, T, IntStream.range(1, 10).mapToDouble(i -> i).toArray()),
47+
arguments(N, 0.09, IntStream.range(1, 100).mapToDouble(i -> 1.0).toArray()),
48+
arguments(N, 0.15, IntStream.range(1, 1000).mapToDouble(i -> 1.0).toArray()),
49+
arguments(10000000, 0.15, IntStream.range(1, 10000).mapToDouble(i -> 1.0).toArray()),
50+
arguments(100000000, 0.16, IntStream.range(1, 100000).mapToDouble(i -> 1.0).toArray()));
51+
}
52+
53+
@ParameterizedTest
54+
@MethodSource("weightsProvider")
55+
public void testWeightedSampler(int trials, double tolerance, double[] weights) {
56+
Integer[] indices = IntStream.range(0, weights.length).boxed().toArray(Integer[]::new);
57+
Function<PseudoRandom, Integer> sampler = SamplingUtils.weightedSampler(indices, weights);
58+
59+
PseudoRandom random = new SeededPseudoRandom(12345);
60+
int[] counts = new int[indices.length];
61+
for (int i = 0; i < trials; i++) {
62+
counts[sampler.apply(random)]++;
63+
}
64+
65+
// Calculate expected probabilities that are proportional to the weights.
66+
double[] pExpected = new double[weights.length];
67+
double sum = 0.0;
68+
for (double w : weights) {
69+
sum += w;
70+
}
71+
for (int i = 0; i < weights.length; i++) {
72+
pExpected[i] = weights[i] / sum;
73+
}
74+
75+
double tol = (double) trials / weights.length * tolerance; // 5% of expected count
76+
// Ensure that the frequencies are within 5% of the expected frequencies.
77+
for (int i = 0; i < weights.length; i++) {
78+
double expectedCount = trials * pExpected[i];
79+
assert Math.abs(counts[i] - expectedCount) < tol
80+
: String.format(
81+
"Count for index %d out of tolerance: got %d, expected ~%.2f",
82+
i, counts[i], expectedCount);
83+
}
84+
}
85+
}

0 commit comments

Comments
 (0)