Skip to content

Commit c6a5fb4

Browse files
committed
filter parametrized layers when checking input and output size of Flux.Chain model
1 parent 66956cc commit c6a5fb4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/predictive_model.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ When only one network is passed as a Chain object, input and output
6363
indexes are directly extracted.
6464
"""
6565
function PredictiveModel(network::Flux.Chain)
66-
input_size = size(network[1].weight)[2]
67-
output_size = size(network[end].weight)[1]
66+
param_layers = [layer for layer in network if has_params(layer)]
67+
input_size = size(param_layers[1].weight, 2)
68+
output_size = size(param_layers[end].weight, 1)
6869
input_output_map = [Dict(collect(1:input_size) => collect(1:output_size))]
6970
return PredictiveModel(
7071
[deepcopy(network)],

0 commit comments

Comments
 (0)