Skip to content

Commit a812927

Browse files
Fix TensorLayer initialization (#489)
Fixes #488
1 parent 536112c commit a812927

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/tensor_product_layer.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ TensorLayer(model,out,p=nothing)
1111
Arguments:
1212
- `model`: Array of TensorProductBasis [B_1(n_1), ..., B_k(n_k)], where k corresponds to the dimension of the input.
1313
- `out`: Dimension of the output.
14-
- `p`: Optional initialization of the layer's weight. Initizalized to 0 by default.
14+
- `p`: Optional initialization of the layer's weight. Initialized to standard normal by default.
1515
"""
1616
struct TensorLayer{M<:Array{TensorProductBasis},P<:AbstractArray,Int} <: AbstractTensorProductLayer
1717
model::M
@@ -23,7 +23,9 @@ struct TensorLayer{M<:Array{TensorProductBasis},P<:AbstractArray,Int} <: Abstrac
2323
for basis in model
2424
number_of_weights *= basis.n
2525
end
26-
p = randn(out*number_of_weights)
26+
if p === nothing
27+
p = randn(out*number_of_weights)
28+
end
2729
new{Array{TensorProductBasis},typeof(p),Int}(model,p,length(model),out)
2830
end
2931
end

0 commit comments

Comments
 (0)