Skip to content

Commit 766e1db

Browse files
committed
fix: fixed c_interface
1 parent ecfa3ba commit 766e1db

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

src/c_interface.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ Base.@ccallable function load_model(path::Cstring)::Cint
7474
data = deserialize(path)
7575
@debug "deserialized"
7676
if typeof(data) <: SerializedModel
77-
global_state.model = StatefulLuxLayer(data.model(), data.parameters,
78-
Lux.initialstates(MersenneTwister(42), data.model()))
77+
global_state.model = production_instantiate(data)
7978
global_state.cutoff_radius = get_cutoff_radius(global_state.model)
8079
0
8180
else

src/models.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ function tiny_angular_dense(; van_der_waals_channel=false, kargs...)
5757
BatchNorm(4 + van_der_waals_channel),
5858
Dense(4 + van_der_waals_channel => 6, elu),
5959
Dense(6 => 1, sigmoid_fast));
60-
name="tiny_angular_dense_" *
61-
(van_der_waals_channel ? "v" : ""),
60+
name="tiny_angular_dense" *
61+
(van_der_waals_channel ? "_v" : ""),
6262
van_der_waals_channel, kargs...)
6363
end
6464

@@ -73,8 +73,8 @@ function light_angular_dense(; van_der_waals_channel=false, kargs...)
7373
BatchNorm(5 + van_der_waals_channel),
7474
Dense(5 + van_der_waals_channel => 10, elu),
7575
Dense(10 => 1, sigmoid_fast));
76-
name="light_angular_dense_" *
77-
(van_der_waals_channel ? "v" : ""),
76+
name="light_angular_dense" *
77+
(van_der_waals_channel ? "_v" : ""),
7878
van_der_waals_channel, kargs...)
7979
end
8080

@@ -91,8 +91,8 @@ function medium_angular_dense(;
9191
Dense(10 + van_der_waals_channel => 5; use_bias=false),
9292
Dense(5 => 10, elu),
9393
Dense(10 => 1, sigmoid_fast));
94-
name="medium_angular_dense_" *
95-
(van_der_waals_channel ? "v" : ""),
94+
name="medium_angular_dense" *
95+
(van_der_waals_channel ? "_v" : ""),
9696
van_der_waals_channel,
9797
kargs...)
9898
end
@@ -113,12 +113,14 @@ get_preprocessing(x::Chain) =
113113

114114
struct SerializedModel
115115
model::Partial
116-
weights::NamedTuple
116+
parameters::NamedTuple
117+
states::NamedTuple
117118
end
118-
production_instantiate((; model, weights)::SerializedModel) = Lux.StatefulLuxLayer(
119+
120+
production_instantiate((; model, parameters,states)::SerializedModel) = Lux.StatefulLuxLayer(
119121
model(),
120-
weights,
121-
Lux.initialparameters(MersenneTwister(42), model.model()) |> Lux.testmode)
122+
parameters,
123+
states |> Lux.testmode)
122124

123125
function get_cutoff_radius(x::Lux.AbstractExplicitLayer)
124126
get_preprocessing(x).fun.kargs[:cutoff_radius]

0 commit comments

Comments
 (0)