Skip to content

Commit c334620

Browse files
committed
Add maxnorm keyword argument to MIPSHash (second commit to fix issue #2).
1 parent 95e709e commit c334620

File tree

4 files changed

+116
-87
lines changed

4 files changed

+116
-87
lines changed

src/hashes/mips_hash.jl

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,37 +21,60 @@ mutable struct MIPSHash{T <: Union{Float32,Float64}} <: AsymmetricLSHFunction
2121
Qshift :: Vector{T}
2222
m :: Int64
2323

24+
# An upper bound on the norm of the data points this hash function will
25+
# process
26+
maxnorm :: T
27+
2428
# Whether or not the number of coefficients per hash function should be
2529
# expanded to be a power of 2 whenever we need to resize coeff_A.
2630
resize_pow2 :: Bool
2731

2832
### Internal MIPSHash constructors
29-
function MIPSHash{T}(
30-
n_hashes::Integer = 1;
31-
scale::Real = 1,
32-
m::Integer = 3,
33-
resize_pow2::Bool = false) where {T <: Union{Float32,Float64}}
34-
35-
if n_hashes < 1
36-
"n_hashes must be positive" |> ErrorException |> throw
37-
elseif scale 0
38-
"scaling factor `scale` must be positive" |> ErrorException |> throw
39-
elseif m 0
40-
"m must be positive" |> ErrorException |> throw
41-
end
42-
43-
coeff_A = Matrix{T}(undef, n_hashes, 0)
44-
coeff_B = randn(T, n_hashes, m)
45-
scale = T(scale)
46-
m = Int64(m)
47-
shift = rand(T, n_hashes)
48-
Qshift = coeff_B * fill(T(1/2), m) ./ scale .+ shift
49-
50-
new{T}(coeff_A, coeff_B, scale, shift, Qshift, m, resize_pow2)
51-
end
5233
end
5334

5435
### External MIPSHash constructors
36+
@generated function MIPSHash{T}(n_hashes::Integer = 1;
37+
maxnorm::Union{Nothing,Real} = nothing,
38+
scale::Real = 1,
39+
m::Integer = 3,
40+
resize_pow2::Bool = false) where T
41+
if maxnorm <: Nothing
42+
:("maxnorm must be specified for MIPSHash" |>
43+
ErrorException |>
44+
throw)
45+
else
46+
quote
47+
if n_hashes < 1
48+
"n_hashes must be positive" |>
49+
ErrorException |>
50+
throw
51+
elseif scale 0
52+
"scaling factor `scale` must be positive" |>
53+
ErrorException |>
54+
throw
55+
elseif m 0
56+
"m must be positive" |>
57+
ErrorException |>
58+
throw
59+
elseif maxnorm 0
60+
"maxnorm must be positive" |>
61+
ErrorException |>
62+
throw
63+
end
64+
65+
coeff_A = Matrix{T}(undef, n_hashes, 0)
66+
coeff_B = randn(T, n_hashes, m)
67+
scale = T(scale)
68+
m = Int64(m)
69+
shift = rand(T, n_hashes)
70+
Qshift = coeff_B * fill(T(1/2), m) ./ scale .+ shift
71+
72+
MIPSHash{T}(coeff_A, coeff_B, scale, shift, Qshift, m,
73+
maxnorm, resize_pow2)
74+
end
75+
end
76+
end
77+
5578

5679
MIPSHash(args...; dtype=Float32, kws...) =
5780
MIPSHash{dtype}(args...; kws...)
@@ -108,20 +131,18 @@ h(P(x)) definitions
108131
end
109132
end
110133

111-
function _MIPSHash_P(h :: MIPSHash{T}, x :: AbstractArray) where {T}
134+
function _MIPSHash_P(hashfn::MIPSHash{T}, x::AbstractArray) where {T}
112135
n = size(x,1)
113-
if n > current_max_input_size(h)
114-
resize!(h, size(x,1))
136+
if n > current_max_input_size(hashfn)
137+
resize!(hashfn, size(x,1))
115138
end
116139

117140
norms = col_norms(x)
118-
maxnorm = maximum(norms)
119-
maxnorm = maxnorm == 0 ? 1 : maxnorm # To handle some edge cases
120-
BLAS.scal!(length(norms), 1/maxnorm, norms, 1)
141+
BLAS.scal!(length(norms), 1/hashfn.maxnorm, norms, 1)
121142

122143
# First, perform a matvec on x and the first array of coefficients.
123144
# Note: aTx is an n_hashes × n_inputs array
124-
@views aTx = h.coeff_A[1:end,1:n] * x .* (1/maxnorm) |> mat
145+
@views aTx = hashfn.coeff_A[1:end,1:n] * x .* (1/hashfn.maxnorm) |> mat
125146

126147
# Compute norms^2, norms^4, ... norms^(2^m).
127148
# Multiply these by the second array of coefficients and add them to aTx, so
@@ -135,13 +156,13 @@ function _MIPSHash_P(h :: MIPSHash{T}, x :: AbstractArray) where {T}
135156
# concatenations.
136157
# Note that m is typically small, so these iterations don't do much to harm
137158
# performance
138-
for ii = 1:h.m
159+
for ii = 1:hashfn.m
139160
norms .^= 2
140-
MIPSHash_P_update_aTx!(h.coeff_B[:,ii], norms, aTx)
161+
MIPSHash_P_update_aTx!(hashfn.coeff_B[:,ii], norms, aTx)
141162
end
142163

143164
# Compute the remainder of the hash the same way we'd compute an L^p distance LSH.
144-
@. aTx = aTx / h.scale + h.shift
165+
@. aTx = aTx / hashfn.scale + hashfn.shift
145166

146167
return floor.(Int32, aTx)
147168
end
@@ -170,7 +191,7 @@ h(Q(x)) definitions
170191
end
171192
end
172193

173-
function _MIPSHash_Q(hashfn::MIPSHash, x::AbstractArray)
194+
function _MIPSHash_Q(hashfn::MIPSHash{T}, x::AbstractArray) where T
174195
n = size(x,1)
175196
if n > current_max_input_size(hashfn)
176197
resize!(hashfn, n)
@@ -184,10 +205,8 @@ function _MIPSHash_Q(hashfn::MIPSHash, x::AbstractArray)
184205
# aTx (rather than before) so that we don't have to allocate a new array
185206
# of size(x). Moreover, for large input vectors, the size of aTx is typically
186207
# much smaller than the size of x.
187-
f(x::T) where {T} = (x T(0) ? T(1) : x)
188208
norms = col_norms(x)
189-
map!(f, norms, norms)
190-
209+
map!(x::T -> x T(0) ? T(1) : x, norms, norms)
191210
aTx .= aTx ./ norms'
192211

193212
# Here, we would multiply the second array of coefficients by the elements

test/hashes/test_mips_hash.jl

Lines changed: 57 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,48 @@ Tests
1212
import SparseArrays: sprandn
1313

1414
@testset "Can construct a simple MIPSHash" begin
15-
hashfn = MIPSHash()
15+
hashfn = MIPSHash(; maxnorm=1)
1616

1717
@test n_hashes(hashfn) == 1
1818
@test hashtype(hashfn) == Vector{Int32}
1919
@test isa(hashfn, MIPSHash{Float32}) # Default dtype should be Float32
2020
@test isa(hashfn, LSH.AsymmetricLSHFunction)
2121

2222
##
23-
hashfn = MIPSHash(12)
23+
hashfn = MIPSHash(12; maxnorm=1)
2424

2525
@test n_hashes(hashfn) == 12
2626

2727
##
28-
hashfn = MIPSHash(; dtype=Float64)
28+
hashfn = MIPSHash(; dtype=Float64, maxnorm=1)
2929

3030
@test isa(hashfn, MIPSHash{Float64})
3131

3232
##
33-
hashfn = MIPSHash{Float64}()
33+
hashfn = MIPSHash{Float64}(; maxnorm=1)
3434
@test isa(hashfn, MIPSHash{Float64})
3535

3636
### Invalid hash function construction
37-
38-
@test_throws ErrorException MIPSHash(-1)
39-
@test_throws ErrorException MIPSHash(; m=-1)
40-
@test_throws ErrorException MIPSHash(; m=0)
41-
@test_throws ErrorException MIPSHash(; scale=-1)
42-
@test_throws ErrorException MIPSHash(; scale=0)
37+
# Non-positive number of hash functions
38+
@test_throws ErrorException MIPSHash(-1; maxnorm=1)
39+
@test_throws ErrorException MIPSHash( 0; maxnorm=1)
40+
41+
# Non-positive m
42+
@test_throws ErrorException MIPSHash(; m = -1, maxnorm=1)
43+
@test_throws ErrorException MIPSHash(; m = 0, maxnorm=1)
44+
45+
# Non-positive scale factor
46+
@test_throws ErrorException MIPSHash(; scale = -1, maxnorm=1)
47+
@test_throws ErrorException MIPSHash(; scale = 0, maxnorm=1)
48+
49+
# maxnorm not specified or non-positive
50+
@test_throws ErrorException MIPSHash()
51+
@test_throws ErrorException MIPSHash(; maxnorm=-1)
52+
@test_throws ErrorException MIPSHash(; maxnorm=0)
4353
end
4454

4555
@testset "Hashing returns the correct data types" begin
46-
hashfn = MIPSHash{Float64}(; scale=1, m=3)
56+
hashfn = MIPSHash{Float64}(; maxnorm=20, scale=1, m=3)
4757

4858
# Matrix{Float64} -> Matrix{Int32}
4959
x = randn(4, 10)
@@ -66,14 +76,16 @@ Tests
6676

6777
@testset "MIPSHash h(P(x)) is correctly computed" begin
6878
n_hashes = 128
69-
scale = 0.5
70-
m = 3
71-
hashfn = MIPSHash(n_hashes; scale=scale, m=m)
79+
scale = 0.5
80+
m = 3
81+
x = randn(20)
82+
maxnorm = 2*norm(x)
83+
84+
hashfn = MIPSHash(n_hashes; maxnorm=maxnorm, scale=scale, m=m)
7285

7386
@test size(hashfn.coeff_B) == (n_hashes, 3)
7487
@test size(hashfn.shift) == (n_hashes,)
7588

76-
x = randn(20)
7789
hash = index_hash(hashfn, x)
7890

7991
@test isa(hash, Vector{Int32})
@@ -87,7 +99,7 @@ Tests
8799
### Compute hash manually
88100
# Start by performing the transform P(x)
89101
coeff = [hashfn.coeff_A hashfn.coeff_B]
90-
u = x / norm(x)
102+
u = x / maxnorm
91103
norm_powers = [norm(u)^2, norm(u)^4, norm(u)^8]
92104
Px = [u; norm_powers]
93105

@@ -100,14 +112,16 @@ Tests
100112

101113
@testset "MIPSHash h(Q(x)) is correctly computed" begin
102114
n_hashes = 128
103-
scale = 0.5
104-
m = 3
105-
hashfn = MIPSHash(n_hashes; scale=scale, m=m)
115+
scale = 0.5
116+
m = 3
117+
x = randn(20)
118+
maxnorm = 2*norm(x)
119+
120+
hashfn = MIPSHash(n_hashes; maxnorm=maxnorm, scale=scale, m=m)
106121

107122
@test size(hashfn.coeff_B) == (n_hashes, m)
108123
@test size(hashfn.shift) == (n_hashes,)
109124

110-
x = randn(40)
111125
hash = query_hash(hashfn, x)
112126

113127
@test isa(hash, Vector{Int32})
@@ -135,7 +149,7 @@ Tests
135149

136150
@testset "Hash inputs of different sizes" begin
137151
n_hashes = 16
138-
hashfn = MIPSHash(n_hashes)
152+
hashfn = MIPSHash(n_hashes; maxnorm=1000)
139153

140154
index_hash(hashfn, rand(10))
141155
@test size(hashfn.coeff_A) == (n_hashes, 10)
@@ -157,7 +171,7 @@ Tests
157171
end
158172

159173
@testset "resize_pow2 increases number of coefficients to powers of 2" begin
160-
hashfn = MIPSHash(10; resize_pow2=true)
174+
hashfn = MIPSHash(10; maxnorm=1000, resize_pow2=true)
161175
@test size(hashfn.coeff_A) == (10, 0)
162176

163177
index_hash(hashfn, rand(3))
@@ -174,45 +188,40 @@ Tests
174188
end
175189

176190
@testset "MIPSHash generates collisions for large inner products" begin
177-
n_hashes = 256
178-
scale = 1
179-
m = 5
180-
hashfn = MIPSHash(n_hashes; scale=scale, m=m)
181-
182-
x = randn(20)
183-
x_query_hashes = query_hash(hashfn, x)
184-
185-
# Check that MIPSHash isn't just generating a single query hash
186-
@test any(x_query_hashes .!= x_query_hashes[1])
191+
input_length = 5; n_hashes = 128;
187192

188-
# Compute the indexing hashes for a dataset with four vectors:
189-
# a) 10 * x (where x is the test query vector)
193+
# Compare a random vector x against four other vectors:
194+
# a) 10 * x
190195
# b) x
191196
# c) A vector of all zeros
192197
# d) -x
193-
dataset = [(10*x) x zero(x) -x]
198+
x = randn(input_length)
199+
x2, x3, x4 = 10*x, zero(x), -x
200+
201+
maxnorm = (x, x2, x3, x4) .|> norm |> maximum
202+
hashfn = MIPSHash(n_hashes; maxnorm=maxnorm)
203+
204+
x_query_hashes = query_hash(hashfn, x)
205+
206+
dataset = [x2 x x3 x4]
194207
p_hashes = index_hash(hashfn, dataset)
195208

196209
# Each collection of hashes should be different from one another
197210
@test let result = true
198-
for (ii,jj) in product(1:4, 1:4)
199-
if ii != jj && p_hashes[:,ii] == p_hashes[:,jj]
200-
result = false
201-
break
202-
end
203-
end
204-
result
211+
for (ii,jj) in Iterators.product(1:4, 1:4)
212+
if ii != jj && p_hashes[:,ii] == p_hashes[:,jj]
213+
result = false
214+
break
215+
end
216+
end
217+
result
205218
end
206-
207-
# The number of collisions should be highest for x and 2*x, second-highest
208-
# for x and x, second-lowest for x and zeros, and lowest for x and -x
209-
n_collisions = [sum(x_query_hashes .== p) for p in eachcol(p_hashes)]
210-
@test n_collisions[1] > n_collisions[2] > n_collisions[3] > n_collisions[4]
211219
end
212220

213221
@testset "Can compute hashes for sparse arrays" begin
214222
X = sprandn(Float32, 10, 1000, 0.2)
215-
hashfn = MIPSHash(8; scale=1, m=1)
223+
maxnorm = X |> eachcol .|> norm |> maximum
224+
hashfn = MIPSHash(8; maxnorm=maxnorm, scale=1, m=1)
216225

217226
ihashes = index_hash(hashfn, X)
218227
qhashes = query_hash(hashfn, X)

test/hashes/test_sign_alsh.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ Tests
110110
# The number of collisions should be highest for x and 2*x, second-highest
111111
# for x and x, second-lowest for x and zeros, and lowest for x and -x
112112
n_collisions = [sum(x_query_hashes .== p) for p in eachcol(p_hashes)]
113-
@test n_collisions[1] > n_collisions[2] > n_collisions[3] > n_collisions[4]
113+
@test n_collisions[1] > n_collisions[2] >
114+
n_collisions[3] > n_collisions[4]
114115
end
115116

116117
@testset "Can hash sparse arrays" begin

test/tables/test_table.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ using Test, Random, LSH
126126
n_inputs = 128
127127
n_hashes = 8
128128

129-
hashfn_mips = MIPSHash(n_hashes)
129+
hashfn_mips = MIPSHash(n_hashes; maxnorm=input_size)
130130
hashfn_sign = SignALSH(n_hashes; maxnorm=input_size)
131131

132132
for hashfn in (hashfn_mips, hashfn_sign)

0 commit comments

Comments
 (0)