Skip to content

Commit ecfa3ba

Browse files
committed
fix: added function production_instantiate
1 parent 257fd74 commit ecfa3ba

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

src/models.jl

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ function select_and_preprocess((point, atoms); cutoff_radius)
44
select_and_preprocess(point, atoms; cutoff_radius)
55
end
66
function select_and_preprocess(
7-
point::Batch, atoms::AnnotedKDTree{Sphere{T}}; cutoff_radius) where {T}
7+
point::Batch, atoms::AnnotedKDTree{Sphere{T}}; cutoff_radius) where {T}
88
neighboord = Folds.map(point.field) do point
99
select_neighboord(point,
1010
atoms;
1111
cutoff_radius)::StructVector{
12-
Sphere{T}, @NamedTuple{center::Vector{Point3{T}}, r::Vector{T}}, Int64}
12+
Sphere{T},@NamedTuple{center::Vector{Point3{T}}, r::Vector{T}},Int64}
1313
end |> Batch{Vector{StructVector{
14-
Sphere{T}, @NamedTuple{center::Vector{Point3{T}}, r::Vector{T}}, Int64}}}
14+
Sphere{T},@NamedTuple{center::Vector{Point3{T}}, r::Vector{T}},Int64}}}
1515
preprocessing((point, neighboord))
1616
end
1717

@@ -24,9 +24,9 @@ function evaluate_if_atoms_in_neighboord(layer, arg::AbstractArray, ps, st; zero
2424
end
2525

2626
function general_angular_dense(main_chain, secondary_chain; name::String,
27-
van_der_waals_channel = false, on_gpu = true, cutoff_radius::Float32 = 3.0f0)
27+
van_der_waals_channel=false, on_gpu=true, cutoff_radius::Float32=3.0f0)
2828
main_chain = DeepSet(Chain(
29-
symetrise(; cutoff_radius, device = on_gpu ? gpu_device() : identity),
29+
symetrise(; cutoff_radius, device=on_gpu ? gpu_device() : identity),
3030
main_chain
3131
))
3232
function add_van_der_waals_channel(main_chain)
@@ -46,7 +46,7 @@ end
4646
`tiny_angular_dense` is a function that generate a lux model.
4747
4848
"""
49-
function tiny_angular_dense(; van_der_waals_channel = false, kargs...)
49+
function tiny_angular_dense(; van_der_waals_channel=false, kargs...)
5050
general_angular_dense(
5151
Parallel(.*,
5252
Chain(Dense(6 => 7, elu),
@@ -57,12 +57,12 @@ function tiny_angular_dense(; van_der_waals_channel = false, kargs...)
5757
BatchNorm(4 + van_der_waals_channel),
5858
Dense(4 + van_der_waals_channel => 6, elu),
5959
Dense(6 => 1, sigmoid_fast));
60-
name = "tiny_angular_dense_" *
61-
(van_der_waals_channel ? "v" : ""),
60+
name="tiny_angular_dense_" *
61+
(van_der_waals_channel ? "v" : ""),
6262
van_der_waals_channel, kargs...)
6363
end
6464

65-
function light_angular_dense(; van_der_waals_channel = false, kargs...)
65+
function light_angular_dense(; van_der_waals_channel=false, kargs...)
6666
general_angular_dense(
6767
Parallel(.*,
6868
Chain(Dense(6 => 10, elu),
@@ -73,13 +73,13 @@ function light_angular_dense(; van_der_waals_channel = false, kargs...)
7373
BatchNorm(5 + van_der_waals_channel),
7474
Dense(5 + van_der_waals_channel => 10, elu),
7575
Dense(10 => 1, sigmoid_fast));
76-
name = "light_angular_dense_" *
77-
(van_der_waals_channel ? "v" : ""),
76+
name="light_angular_dense_" *
77+
(van_der_waals_channel ? "v" : ""),
7878
van_der_waals_channel, kargs...)
7979
end
8080

8181
function medium_angular_dense(;
82-
van_der_waals_channel = false, kargs...)
82+
van_der_waals_channel=false, kargs...)
8383
general_angular_dense(
8484
Parallel(.*,
8585
Chain(Dense(6 => 15, elu),
@@ -88,17 +88,17 @@ function medium_angular_dense(;
8888
),
8989
Chain(
9090
BatchNorm(10 + van_der_waals_channel),
91-
Dense(10 + van_der_waals_channel => 5; use_bias = false),
91+
Dense(10 + van_der_waals_channel => 5; use_bias=false),
9292
Dense(5 => 10, elu),
9393
Dense(10 => 1, sigmoid_fast));
94-
name = "medium_angular_dense_" *
95-
(van_der_waals_channel ? "v" : ""),
94+
name="medium_angular_dense_" *
95+
(van_der_waals_channel ? "v" : ""),
9696
van_der_waals_channel,
9797
kargs...)
9898
end
9999
function drop_preprocessing(x::Chain)
100100
if typeof(x[1]) <: PreprocessingLayer
101-
Chain(NoOpLayer(), map(i -> x[i], 2:length(x))..., disable_optimizations = true)
101+
Chain(NoOpLayer(), map(i -> x[i], 2:length(x))..., disable_optimizations=true)
102102
else
103103
x
104104
end
@@ -115,6 +115,11 @@ struct SerializedModel
115115
model::Partial
116116
weights::NamedTuple
117117
end
118+
production_instantiate((; model, weights)::SerializedModel) = Lux.StatefulLuxLayer(
119+
model(),
120+
weights,
121+
Lux.initialparameters(MersenneTwister(42), model.model()) |> Lux.testmode)
122+
118123
function get_cutoff_radius(x::Lux.AbstractExplicitLayer)
119124
get_preprocessing(x).fun.kargs[:cutoff_radius]
120125
end

0 commit comments

Comments
 (0)