Skip to content

Commit f7ed395

Browse files
committed
Fixes problem where different samplers have the same name
1 parent aeab3fd commit f7ed395

File tree

5 files changed

+79
-39
lines changed

5 files changed

+79
-39
lines changed

benchmarks/src/Benchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module Benchmarks
22

33
export BenchmarkCondition, generate_conditions, is_compatible, construct_samplers
4-
export setup_sampler, benchmark_step!, benchmark_config
4+
export setup_sampler, benchmark_step!, benchmark_config, sampler_name
55
export N_ENABLED, N_CHANGES, DISTRIBUTIONS, KEY_STRATEGIES
66

77
include("conditions.jl")

benchmarks/src/benchmark_runner.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ function main()
3737
continue
3838
end
3939

40-
println("[$done/$total] Running: $(nameof(sampler_type)), n_enabled=$(cond.n_enabled), n_changes=$(cond.n_changes), dist=$(cond.distributions), keys=$(cond.key_strategy)")
40+
println("[$done/$total] Running: $(sampler_name(sampler_type)), n_enabled=$(cond.n_enabled), n_changes=$(cond.n_changes), dist=$(cond.distributions), keys=$(cond.key_strategy)")
4141

4242
time_ns, mem_bytes = benchmark_config(sampler, cond)
4343

4444
push!(results, (
45-
string(nameof(sampler_type)),
45+
sampler_name(sampler_type),
4646
cond.n_enabled,
4747
cond.n_changes,
4848
string(cond.distributions),

benchmarks/src/conditions.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,9 @@ Generate all valid benchmark conditions.
3737
Filters out invalid combinations where n_changes > n_enabled.
3838
"""
3939
function generate_conditions()
40-
conditions = BenchmarkCondition[]
41-
42-
for n in N_ENABLED, m in N_CHANGES, d in DISTRIBUTIONS, k in KEY_STRATEGIES
43-
# Skip invalid conditions where we try to change more clocks than exist
44-
if m < n
45-
push!(conditions, BenchmarkCondition(n, m, d, k))
46-
end
47-
push!(conditions, BenchmarkCondition(n, n, d, k))
48-
end
49-
50-
return conditions
40+
[BenchmarkCondition(n, m, d, k)
41+
for n in N_ENABLED, m in N_CHANGES, d in DISTRIBUTIONS, k in KEY_STRATEGIES
42+
if m <= n]
5143
end
5244

5345
"""
@@ -56,8 +48,10 @@ Check if a sampler type is compatible with a given condition.
5648
DirectCall only works with exponential distributions.
5749
"""
5850
function is_compatible(sampler_type::Type, cond::BenchmarkCondition)
59-
# DirectCall can only handle exponential distributions
60-
if nameof(sampler_type) == :DirectCall && cond.distributions != :exponential
51+
# DirectCall variants can only handle exponential distributions
52+
# Check type name starts with "DirectCall" to catch all variants
53+
type_name = string(nameof(sampler_type))
54+
if startswith(type_name, "DirectCall") && cond.distributions != :exponential
6155
return false
6256
end
6357

benchmarks/src/measure.jl

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,42 @@ using Random
66
# Import from Base.Iterators
77
import Base.Iterators: take, flatten
88

9+
"""
10+
Generate a descriptive, unique name for a sampler type.
11+
12+
For DirectCall variants, extracts the Keep/Removal strategy and PrefixSearch type.
13+
For other samplers, returns the simple type name.
14+
"""
15+
function sampler_name(sampler_type::Type)
16+
type_str = string(sampler_type)
17+
18+
# For DirectCall, extract the variant information
19+
if occursin("DirectCall", type_str)
20+
# Extract KeyedRemovalPrefixSearch vs KeyedKeepPrefixSearch
21+
keep_strategy = if occursin("KeyedRemoval", type_str)
22+
"Removal"
23+
elseif occursin("KeyedKeep", type_str)
24+
"Keep"
25+
else
26+
"Unknown"
27+
end
28+
29+
# Extract BinaryTree vs CumSum
30+
prefix_type = if occursin("BinaryTree", type_str)
31+
"BinaryTree"
32+
elseif occursin("CumSum", type_str)
33+
"CumSum"
34+
else
35+
"Unknown"
36+
end
37+
38+
return "DirectCall_$(keep_strategy)_$(prefix_type)"
39+
end
40+
41+
# For other samplers, use simple name
42+
return string(nameof(sampler_type))
43+
end
44+
945
"""
1046
Create a distribution instance based on the distribution type symbol.
1147
"""
@@ -70,24 +106,33 @@ function benchmark_step!(sampler, enabled_keys, key_strategy, n_changes, dist_ty
70106
end
71107

72108
# 4. Determine which keys to re-enable
73-
if key_strategy == :dense
74-
all_keys_to_enable = flatten([[what_fire], disable_keys])
109+
# 4. Determine which keys to re-enable (always as Vector for type stability)
110+
all_keys_to_enable = if key_strategy == :dense
111+
vcat([what_fire], disable_keys) # Concrete Vector{Int}
75112
elseif key_strategy == :sparse
76113
# Generate brand new random keys for sparse strategy
77-
all_keys_to_enable = Int[]
78-
sizehint!(all_keys_to_enable, n_changes)
79-
while length(all_keys_to_enable) < n_changes
80-
batch = rand(rng, 1:typemax(Int32), 16)
81-
valid_keys = filter(k -> k enabled_keys && k all_keys_to_enable, batch)
82-
needed = n_changes - length(all_keys_to_enable)
83-
append!(all_keys_to_enable, first(valid_keys, min(needed, length(valid_keys))))
114+
new_keys = Int[]
115+
sizehint!(new_keys, n_changes)
116+
seen = Set{Int}(enabled_keys) # O(1) lookups
117+
118+
while length(new_keys) < n_changes
119+
batch_size = min(32, (n_changes - length(new_keys)) * 2)
120+
batch = rand(rng, 1:typemax(Int), batch_size)
121+
for k in batch
122+
if k seen
123+
push!(new_keys, k)
124+
push!(seen, k)
125+
length(new_keys) == n_changes && break
126+
end
127+
end
84128
end
129+
new_keys
85130
else
86131
error("Unknown key strategy: $key_strategy")
87-
end
132+
end::Vector{Int} # Type assertion
88133

89134
# 5. Re-enable all n_changes clocks with new distributions
90-
for k in first(all_keys_to_enable, n_changes)
135+
for k in all_keys_to_enable
91136
dist = create_distribution(dist_type, rng)
92137
enable!(sampler, k, dist, when_fire, when_fire, rng)
93138
push!(enabled_keys, k)
@@ -102,15 +147,16 @@ Run a complete benchmark for a given sampler type and condition.
102147
Returns (median_time_ns, median_memory_bytes).
103148
"""
104149
function benchmark_config(sampler, cond::BenchmarkCondition)
105-
# Set up the sampler
106-
enabled_keys, rng, dist_type = setup_sampler(sampler, cond)
107-
when = 0.0
108-
109-
# Run the benchmark
110-
result = @benchmark benchmark_step!(
111-
$sampler, $enabled_keys, $(cond.key_strategy), $(cond.n_changes),
112-
$dist_type, $rng, $when
113-
) samples=100
150+
# Run the benchmark with proper setup between samples
151+
result = @benchmark begin
152+
benchmark_step!(
153+
$sampler, enabled_keys, $(cond.key_strategy), $(cond.n_changes),
154+
dist_type, rng, 0.0
155+
)
156+
end setup=begin
157+
reset!($sampler)
158+
enabled_keys, rng, dist_type = setup_sampler($sampler, $cond)
159+
end samples=100
114160

115161
return Int(round(median(result.times))), Int(round(median(result.memory)))
116162
end

benchmarks/src/run_all.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function main()
2929
println()
3030
println("Sampler types:")
3131
for s in samplers
32-
println("$(nameof(typeof(s)))")
32+
println("$(sampler_name(typeof(s)))")
3333
end
3434
println()
3535
println("Configuration ranges:")
@@ -68,7 +68,7 @@ function main()
6868
continue
6969
end
7070

71-
print("[$done/$total_combinations] $(nameof(sampler_type)): ")
71+
print("[$done/$total_combinations] $(sampler_name(sampler_type)): ")
7272
print("n=$(cond.n_enabled), churn=$(cond.n_changes), ")
7373
print("dist=$(cond.distributions), keys=$(cond.key_strategy)... ")
7474
flush(stdout)
@@ -77,7 +77,7 @@ function main()
7777
time_ns, mem_bytes = benchmark_config(sampler, cond)
7878

7979
push!(results, (
80-
string(nameof(sampler_type)),
80+
sampler_name(sampler_type),
8181
cond.n_enabled,
8282
cond.n_changes,
8383
string(cond.distributions),

0 commit comments

Comments
 (0)