Skip to content

Commit 7115db9

Browse files
pritning fixes
1 parent f53a5f4 commit 7115db9

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

src/layers/basic.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ end
5252
# only slightly changed to better handle interaction with Zygote @dsweber2
5353
"""
5454
activations(c::Chain, input)
55+
5556
Calculate the forward results of each layers in Chain `c` with `input` as model input.
5657
"""
5758
function activations(c::Chain, input)
@@ -100,7 +101,7 @@ julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix
100101
Dense(5, 2, tanh; bias=false)
101102
102103
julia> d1(ones(5))
103-
2-element Array{Float64,1}:
104+
2-element Vector{Float64}:
104105
0.9999092042625951
105106
0.9999092042625951
106107
@@ -384,6 +385,7 @@ Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l
384385
If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
385386
386387
# Examples
388+
387389
```jldoctest
388390
julia> model = Chain(Dense(3, 5),
389391
Parallel(vcat, Dense(5, 4), Chain(Dense(5, 7), Dense(7, 4))),
@@ -394,6 +396,7 @@ julia> size(model(rand(3)))
394396
395397
julia> model = Parallel(+, Dense(10, 2), Dense(5, 2))
396398
Parallel(+, Dense(10, 2), Dense(5, 2))
399+
397400
julia> size(model(rand(10), rand(5)))
398401
(2,)
399402
```

src/utils.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -390,12 +390,7 @@ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
390390
end
391391
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
392392
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
393-
if eltype(bias) == eltype(weights)
394-
return bias
395-
else
396-
@warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
397-
return broadcast(eltype(weights), bias)
398-
end
393+
bias
399394
end
400395

401396
"""

0 commit comments

Comments
 (0)