Skip to content

Commit 257fd74

Browse files
committed
fix: fixed error caused by Lux update
1 parent 0be8980 commit 257fd74

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/layers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ end
147147
scale_factor(x) = x[end:end, :]
148148

149149
function symetrise(; cutoff_radius::Number, device)
150-
Partial(symetrise; cutoff_radius, device) |> Lux.WrappedFunction
150+
Partial(symetrise; cutoff_radius, device) |> Lux.WrappedFunction{:direct_call}
151151
end
152152

153153
function trace(message::String, x)

src/models.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function general_angular_dense(main_chain, secondary_chain; name::String,
3232
function add_van_der_waals_channel(main_chain)
3333
Parallel(vcat,
3434
main_chain,
35-
WrappedFunction{:direct_call}((x -> Float32.(x)) is_in_van_der_waals))
35+
WrappedFunction{:direct_call}((x -> Float32.(x)) is_in_van_der_waals))
3636
end
3737
Chain(PreprocessingLayer(Partial(select_and_preprocess; cutoff_radius)),
3838
main_chain |> (van_der_waals_channel ? add_van_der_waals_channel : identity),
@@ -96,12 +96,13 @@ function medium_angular_dense(;
9696
van_der_waals_channel,
9797
kargs...)
9898
end
99-
drop_preprocessing(x::Chain) =
99+
function drop_preprocessing(x::Chain)
100100
if typeof(x[1]) <: PreprocessingLayer
101-
Chain(NoOpLayer(), x[2:end])
101+
Chain(NoOpLayer(), map(i -> x[i], 2:length(x))..., disable_optimizations = true)
102102
else
103103
x
104104
end
105+
end
105106

106107
get_preprocessing(x::Chain) =
107108
if typeof(x[1]) <: PreprocessingLayer

0 commit comments

Comments
 (0)