Skip to content

Commit 52f2ffc

Browse files
committed
fix: made models continuous
1 parent 96fad0c commit 52f2ffc

File tree

2 files changed

+30
-21
lines changed

2 files changed

+30
-21
lines changed

src/layers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ function symetrise(val::StructArray{PreprocessedData{T}};
141141
vcat(dot,
142142
r_1 .+ r_2,
143143
abs.(r_1 .- r_2),
144-
d_1 .+ d_2, abs.(d_1 .- d_2)) .*
145-
cut.(cutoff_radius, r_1) .* cut.(cutoff_radius, r_2)
144+
d_1 .+ d_2, abs.(d_1 .- d_2),
145+
cut.(cutoff_radius, r_1) .* cut.(cutoff_radius, r_2))
146146
end
147147
scale_factor(x) = x[end:end, :]
148148

src/models.jl

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ function evaluate_if_atoms_in_neighboord(layer, arg::AbstractArray, ps, st; zero
1919
end
2020

2121
function general_angular_dense(main_chain, secondary_chain; name::String,
22-
van_der_waals_channel = false, on_gpu = true, cutoff_radius::Float32 = 3.0f0)
22+
van_der_waals_channel=false, on_gpu=true, cutoff_radius::Float32=3.0f0)
2323
main_chain = DeepSet(Chain(
24-
symetrise(; cutoff_radius, device = on_gpu ? gpu_device() : identity),
24+
symetrise(; cutoff_radius, device=on_gpu ? gpu_device() : identity),
2525
main_chain
2626
))
2727
function add_van_der_waals_channel(main_chain)
@@ -41,44 +41,53 @@ end
4141
`tiny_angular_dense` is a function that generate a lux model.
4242
4343
"""
44-
function tiny_angular_dense(; van_der_waals_channel = false, kargs...)
44+
function tiny_angular_dense(; van_der_waals_channel=false, kargs...)
4545
general_angular_dense(
46-
Chain(Dense(5 => 7, elu),
47-
Dense(7 => 4, elu)),
46+
Parallel(.*,
47+
Chain(Dense(6 => 7, elu),
48+
Dense(7 => 4, elu)),
49+
Lux.WrappedFunction(scale_factor)
50+
),
4851
Chain(
4952
BatchNorm(4 + van_der_waals_channel),
5053
Dense(4 + van_der_waals_channel => 6, elu),
5154
Dense(6 => 1, sigmoid_fast));
52-
name = "tiny_angular_dense_" *
53-
(van_der_waals_channel ? "v" : ""),
55+
name="tiny_angular_dense_" *
56+
(van_der_waals_channel ? "v" : ""),
5457
van_der_waals_channel, kargs...)
5558
end
5659

57-
function light_angular_dense(; van_der_waals_channel = false, kargs...)
60+
function light_angular_dense(; van_der_waals_channel=false, kargs...)
5861
general_angular_dense(
59-
Chain(Dense(5 => 10, elu),
60-
Dense(10 => 5, elu)),
62+
Parallel(.*,
63+
Chain(Dense(6 => 10, elu),
64+
Dense(10 => 5, elu)),
65+
Lux.WrappedFunction(scale_factor)
66+
),
6167
Chain(
6268
BatchNorm(5 + van_der_waals_channel),
6369
Dense(5 + van_der_waals_channel => 10, elu),
6470
Dense(10 => 1, sigmoid_fast));
65-
name = "light_angular_dense_" *
66-
(van_der_waals_channel ? "v" : ""),
71+
name="light_angular_dense_" *
72+
(van_der_waals_channel ? "v" : ""),
6773
van_der_waals_channel, kargs...)
6874
end
6975

7076
function medium_angular_dense(;
71-
van_der_waals_channel = false, kargs...)
72-
general_angular_dense(Chain(
73-
Dense(5 => 15, elu),
74-
Dense(15 => 10, elu)),
77+
van_der_waals_channel=false, kargs...)
78+
general_angular_dense(
79+
Parallel(.*,
80+
Chain(Dense(6 => 15, elu),
81+
Dense(15 => 5, elu)),
82+
Lux.WrappedFunction(scale_factor)
83+
),
7584
Chain(
7685
BatchNorm(10 + van_der_waals_channel),
77-
Dense(10 + van_der_waals_channel => 5; use_bias = false),
86+
Dense(10 + van_der_waals_channel => 5; use_bias=false),
7887
Dense(5 => 10, elu),
7988
Dense(10 => 1, sigmoid_fast));
80-
name = "medium_angular_dense_" *
81-
(van_der_waals_channel ? "v" : ""),
89+
name="medium_angular_dense_" *
90+
(van_der_waals_channel ? "v" : ""),
8291
van_der_waals_channel,
8392
kargs...)
8493
end

0 commit comments

Comments
 (0)