Skip to content

Commit 95e709e

Browse files
committed
Add maxnorm keyword argument to SignALSH constructor (partial fix to issue #2).
1 parent 0fe0b04 commit 95e709e

File tree

3 files changed

+95
-28
lines changed

3 files changed

+95
-28
lines changed

src/hashes/sign_alsh.jl

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,39 @@ mutable struct SignALSH{T <: Union{Float32,Float64}} <: AsymmetricLSHFunction
2222
P_shift :: Vector{T}
2323
m :: Int64
2424

25+
# An upper bound on the norm of the data points this hash function will
26+
# process
27+
maxnorm :: T
28+
2529
# Whether or not SignALSH should round up to the next power of 2 when
2630
# resizing its coefficient array.
2731
resize_pow2 :: Bool
2832
end
2933

3034
### External SignALSH constructors
31-
function SignALSH{T}(n_hashes::Integer;
32-
m::Integer = 3,
33-
resize_pow2::Bool = false) where {T}
34-
35-
coeff_A = Matrix{T}(undef, n_hashes, 0)
36-
coeff_B = randn(T, n_hashes, m)
37-
P_shift = coeff_B * fill(T(1/2), m)
38-
39-
SignALSH(coeff_A, coeff_B, P_shift, Int64(m), resize_pow2)
35+
@generated function SignALSH{T}(n_hashes::Integer = 1;
36+
maxnorm::Union{Nothing,Real} = nothing,
37+
m::Integer = 3,
38+
resize_pow2::Bool = false) where {T}
39+
40+
if maxnorm <: Nothing
41+
:("maxnorm must be specified for SignALSH" |> ErrorException |> throw)
42+
else
43+
quote
44+
if maxnorm < 0
45+
"maxnorm must be non-negative" |> ErrorException |> throw
46+
elseif m 0
47+
"m must be positive" |> ErrorException |> throw
48+
end
49+
50+
coeff_A = Matrix{T}(undef, n_hashes, 0)
51+
coeff_B = randn(T, n_hashes, m)
52+
P_shift = coeff_B * fill(T(1/2), m)
53+
54+
SignALSH(coeff_A, coeff_B, P_shift, Int64(m),
55+
T(maxnorm), resize_pow2)
56+
end
57+
end
4058
end
4159

4260
SignALSH(args...; dtype=Float32, kws...) =
@@ -86,16 +104,23 @@ function SignALSH_P(hashfn::SignALSH{T}, x::AbstractArray{T}) where {T}
86104
# after dividing through x by the largest norm of all of the columns
87105
# of x.
88106
norms = col_norms(x)
89-
maxnorm = maximum(norms)
90-
maxnorm = maxnorm == 0 ? 1 : maxnorm # To handle some edge cases
91-
norms .*= 1/maxnorm
107+
108+
for norm_ii in norms
109+
if norm_ii > hashfn.maxnorm
110+
"norm $(norm_ii) exceeds hashfn.maxnorm ($(hashfn.maxnorm))" |>
111+
ErrorException |>
112+
throw
113+
end
114+
end
115+
116+
norms .*= 1/hashfn.maxnorm
92117

93118
n = size(x,1)
94119
if n > current_max_input_size(hashfn)
95120
resize!(hashfn, n)
96121
end
97122

98-
Ax = @views hashfn.coeff_A[1:end,1:n] * x .* (1/maxnorm)
123+
Ax = @views hashfn.coeff_A[1:end,1:n] * x .* (1/hashfn.maxnorm)
99124

100125
# Perform the transformation P(x) on x, except that instead of actually
101126
# allocating memory for it and computing it, pile it onto Ax
@@ -157,6 +182,15 @@ function SignALSH_Q(hashfn::SignALSH{T}, x::AbstractArray{T}) where {T}
157182

158183
Ax = @views hashfn.coeff_A[1:end,1:n] * x
159184
norms = col_norms(x)
185+
186+
for norm_ii in norms
187+
if norm_ii > hashfn.maxnorm
188+
"norm $(norm_ii) exceeds hashfn.maxnorm ($(hashfn.maxnorm))" |>
189+
ErrorException |>
190+
throw
191+
end
192+
end
193+
160194
map!(inv, norms, norms)
161195
@. Ax * norms' T(0)
162196
end

test/hashes/test_sign_alsh.jl

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,28 @@ Tests
88
@testset "SignALSH tests" begin
99
Random.seed!(RANDOM_SEED)
1010

11+
@testset "Construct SignALSH" begin
12+
hashfn = SignALSH(; maxnorm=1)
13+
@test n_hashes(hashfn) == 1
14+
@test isa(hashfn, SignALSH{Float32})
15+
@test isa(hashfn, LSH.AsymmetricLSHFunction)
16+
@test hashtype(hashfn) == BitArray{1}
17+
18+
hashfn = SignALSH(32; maxnorm=1)
19+
@test n_hashes(hashfn) == 32
20+
21+
hashfn = SignALSH(; dtype=Float64, maxnorm=1)
22+
@test isa(hashfn, SignALSH{Float64})
23+
24+
# maxnorm must be specified and non-negative
25+
@test_throws ErrorException SignALSH()
26+
@test_throws ErrorException SignALSH(; maxnorm=-1)
27+
28+
# m must be positive
29+
@test_throws ErrorException SignALSH(; m=-1)
30+
@test_throws ErrorException SignALSH(; m=0)
31+
end
32+
1133
@test_skip @testset "Can hash inputs correctly with SignALSH" begin
1234
input_length = 5
1335
n_hashes = 8
@@ -45,22 +67,33 @@ Tests
4567
@test simhash(Qx) == qhashes
4668
end
4769

48-
@testset "SignALSH generates collisions for large inner products" begin
49-
input_length = 5; n_hashes = 128;
50-
hashfn = SignALSH(n_hashes)
70+
@testset "SignALSH can't hash inputs of norm > maxnorm" begin
71+
hashfn = SignALSH(; maxnorm=0)
72+
@test_throws ErrorException index_hash(hashfn, rand(4))
73+
@test_throws ErrorException query_hash(hashfn, rand(4))
5174

52-
x = randn(input_length)
53-
x_query_hashes = query_hash(hashfn, x)
75+
# Should have no issue if norm(x) == maxnorm
76+
@test index_hash(hashfn, zeros(4)) |> length == 1
77+
@test query_hash(hashfn, zeros(4)) |> length == 1
78+
end
5479

55-
# Check that SignALSH isn't just generating a single query hash
56-
@test any(x_query_hashes .!= x_query_hashes[1])
80+
@testset "SignALSH generates collisions for large inner products" begin
81+
input_length = 5; n_hashes = 128;
5782

58-
# Compute the indexing hashes for a dataset with four vectors:
59-
# a) 10 * x (where x is the test query vector)
83+
# Compare a random vector x against four other vectors:
84+
# a) 10 * x
6085
# b) x
6186
# c) A vector of all zeros
6287
# d) -x
63-
dataset = [(10*x) x zeros(input_length) -x]
88+
x = randn(input_length)
89+
x2, x3, x4 = 10*x, zero(x), -x
90+
91+
maxnorm = (x, x2, x3, x4) .|> norm |> maximum
92+
hashfn = SignALSH(n_hashes; maxnorm=maxnorm)
93+
94+
x_query_hashes = query_hash(hashfn, x)
95+
96+
dataset = [x2 x x3 x4]
6497
p_hashes = index_hash(hashfn, dataset)
6598

6699
# Each collection of hashes should be different from one another
@@ -85,7 +118,7 @@ Tests
85118
n_inputs = 150
86119
n_hashes = 2
87120

88-
hashfn = SignALSH(n_hashes)
121+
hashfn = SignALSH(n_hashes; maxnorm=4*input_size)
89122
x = sprandn(input_size, n_inputs, 0.2)
90123

91124
# Mostly just need to test that the following lines don't crash
@@ -100,7 +133,7 @@ Tests
100133
input_size = 100
101134
n_inputs = 150
102135
n_hashes = 2
103-
hashfn = SignALSH(n_hashes)
136+
hashfn = SignALSH(n_hashes; maxnorm=4*input_size)
104137

105138
## Test 1: regular matrix adjoint
106139
x = randn(n_inputs, input_size)'
@@ -115,7 +148,7 @@ Tests
115148

116149
@testset "Hash inputs of different sizes" begin
117150
n_hashes = 42
118-
hashfn = SignALSH(n_hashes)
151+
hashfn = SignALSH(n_hashes; maxnorm=100)
119152

120153
@test size(hashfn.coeff_A) == (n_hashes, 0)
121154

@@ -140,7 +173,7 @@ Tests
140173

141174
@testset "Hash inputs of different sizes with resize_pow2 = true" begin
142175
n_hashes = 25
143-
hashfn = SignALSH(n_hashes; resize_pow2=true)
176+
hashfn = SignALSH(n_hashes; maxnorm=100, resize_pow2=true)
144177

145178
@test size(hashfn.coeff_A) == (n_hashes, 0)
146179

test/tables/test_table.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ using Test, Random, LSH
127127
n_hashes = 8
128128

129129
hashfn_mips = MIPSHash(n_hashes)
130-
hashfn_sign = SignALSH(n_hashes)
130+
hashfn_sign = SignALSH(n_hashes; maxnorm=input_size)
131131

132132
for hashfn in (hashfn_mips, hashfn_sign)
133133
table = LSHTable(hashfn)

0 commit comments

Comments
 (0)