Skip to content

Commit 3c4875e

Browse files
author
Michael Abbott
committed
use inner constructor for Bilinear, more like Dense
1 parent 644ec8c commit 3c4875e

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

src/layers/basic.jl

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -287,7 +287,7 @@ end
287287
Bilinear(W::AbstractArray, [bias, σ])
288288
289289
Creates a Bilinear layer, which operates on two inputs at the same time.
290-
It its output, given vectors `x`, `y` is another vector `z` with,
290+
Its output, given vectors `x` & `y`, is another vector `z` with,
291291
for all `i ∈ 1:out`:
292292
293293
z[i] = σ(x' * W[i,:,:] * y + bias[i])
@@ -323,23 +323,27 @@ julia> sc = SkipConnection(
323323
324324
julia> sc(x) |> size
325325
(3, 32)
326+
327+
julia> Flux.Bilinear(rand(4,8,16), false, tanh) # first dim of weight is the output
328+
Bilinear(8, 16, 4, tanh, bias=false)
326329
```
327330
"""
328-
struct Bilinear{A,B,S}
331+
struct Bilinear{F,A,B}
329332
weight::A
330333
bias::B
331-
σ::S
334+
σ::F
335+
function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
336+
ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
337+
b = create_bias(W, bias, size(W,1))
338+
new{F,A,typeof(b)}(W, b, σ)
339+
end
332340
end
333341

334342
@functor Bilinear
335343

336-
Bilinear(weight::AbstractArray, bias = true) = Bilinear(weight, create_bias(weight, bias, size(weight,1)), identity)
337-
338344
function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
339-
init = glorot_uniform, bias = true)
340-
W = init(out, in1, in2)
341-
b = create_bias(W, bias, out)
342-
return Bilinear(W, b, σ)
345+
init = glorot_uniform, bias = true)
346+
Bilinear(init(out, in1, in2), bias, σ)
343347
end
344348

345349
function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)

test/layers/basic.jl

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -161,12 +161,17 @@ import Flux: activations
161161
b2 = Flux.Bilinear(randn(3,4,5), false)
162162
@test b2.bias == Flux.Zeros()
163163

164-
b3 = Flux.Bilinear(randn(3,4,5), true, tanh)
164+
b3 = Flux.Bilinear(randn(Float16, 3,4,5), true, tanh)
165165
@test b3.σ == tanh
166+
@test b3.bias isa Vector{Float16}
166167
@test size(b3(rand(4), rand(5))) == (3,)
167168

168169
b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros)
169170
@test b4.bias isa Vector{Float32}
171+
172+
@test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array
173+
@test_throws ArgumentError Flux.Bilinear(rand(3,4), false, tanh)
174+
@test_throws DimensionMismatch Flux.Bilinear(rand(3,4,5), rand(6), tanh) # wrong length bias
170175
end
171176
end
172177

0 commit comments

Comments
 (0)