@@ -4,14 +4,14 @@ function select_and_preprocess((point, atoms); cutoff_radius)
4
4
select_and_preprocess (point, atoms; cutoff_radius)
5
5
end
6
6
function select_and_preprocess (
7
- point:: Batch , atoms:: AnnotedKDTree{Sphere{T}} ; cutoff_radius) where {T}
7
+ point:: Batch , atoms:: AnnotedKDTree{Sphere{T}} ; cutoff_radius) where {T}
8
8
neighboord = Folds. map (point. field) do point
9
9
select_neighboord (point,
10
10
atoms;
11
11
cutoff_radius):: StructVector {
12
- Sphere{T}, @NamedTuple {center:: Vector{Point3{T}} , r:: Vector{T} }, Int64}
12
+ Sphere{T},@NamedTuple {center:: Vector{Point3{T}} , r:: Vector{T} },Int64}
13
13
end |> Batch{Vector{StructVector{
14
- Sphere{T}, @NamedTuple {center:: Vector{Point3{T}} , r:: Vector{T} }, Int64}}}
14
+ Sphere{T},@NamedTuple {center:: Vector{Point3{T}} , r:: Vector{T} },Int64}}}
15
15
preprocessing ((point, neighboord))
16
16
end
17
17
@@ -24,9 +24,9 @@ function evaluate_if_atoms_in_neighboord(layer, arg::AbstractArray, ps, st; zero
24
24
end
25
25
26
26
function general_angular_dense (main_chain, secondary_chain; name:: String ,
27
- van_der_waals_channel = false , on_gpu = true , cutoff_radius:: Float32 = 3.0f0 )
27
+ van_der_waals_channel= false , on_gpu= true , cutoff_radius:: Float32 = 3.0f0 )
28
28
main_chain = DeepSet (Chain (
29
- symetrise (; cutoff_radius, device = on_gpu ? gpu_device () : identity),
29
+ symetrise (; cutoff_radius, device= on_gpu ? gpu_device () : identity),
30
30
main_chain
31
31
))
32
32
function add_van_der_waals_channel (main_chain)
46
46
`tiny_angular_dense` is a function that generate a lux model.
47
47
48
48
"""
49
- function tiny_angular_dense (; van_der_waals_channel = false , kargs... )
49
+ function tiny_angular_dense (; van_der_waals_channel= false , kargs... )
50
50
general_angular_dense (
51
51
Parallel (.* ,
52
52
Chain (Dense (6 => 7 , elu),
@@ -57,12 +57,12 @@ function tiny_angular_dense(; van_der_waals_channel = false, kargs...)
57
57
BatchNorm (4 + van_der_waals_channel),
58
58
Dense (4 + van_der_waals_channel => 6 , elu),
59
59
Dense (6 => 1 , sigmoid_fast));
60
- name = " tiny_angular_dense_" *
61
- (van_der_waals_channel ? " v" : " " ),
60
+ name= " tiny_angular_dense_" *
61
+ (van_der_waals_channel ? " v" : " " ),
62
62
van_der_waals_channel, kargs... )
63
63
end
64
64
65
- function light_angular_dense (; van_der_waals_channel = false , kargs... )
65
+ function light_angular_dense (; van_der_waals_channel= false , kargs... )
66
66
general_angular_dense (
67
67
Parallel (.* ,
68
68
Chain (Dense (6 => 10 , elu),
@@ -73,13 +73,13 @@ function light_angular_dense(; van_der_waals_channel = false, kargs...)
73
73
BatchNorm (5 + van_der_waals_channel),
74
74
Dense (5 + van_der_waals_channel => 10 , elu),
75
75
Dense (10 => 1 , sigmoid_fast));
76
- name = " light_angular_dense_" *
77
- (van_der_waals_channel ? " v" : " " ),
76
+ name= " light_angular_dense_" *
77
+ (van_der_waals_channel ? " v" : " " ),
78
78
van_der_waals_channel, kargs... )
79
79
end
80
80
81
81
function medium_angular_dense (;
82
- van_der_waals_channel = false , kargs... )
82
+ van_der_waals_channel= false , kargs... )
83
83
general_angular_dense (
84
84
Parallel (.* ,
85
85
Chain (Dense (6 => 15 , elu),
@@ -88,17 +88,17 @@ function medium_angular_dense(;
88
88
),
89
89
Chain (
90
90
BatchNorm (10 + van_der_waals_channel),
91
- Dense (10 + van_der_waals_channel => 5 ; use_bias = false ),
91
+ Dense (10 + van_der_waals_channel => 5 ; use_bias= false ),
92
92
Dense (5 => 10 , elu),
93
93
Dense (10 => 1 , sigmoid_fast));
94
- name = " medium_angular_dense_" *
95
- (van_der_waals_channel ? " v" : " " ),
94
+ name= " medium_angular_dense_" *
95
+ (van_der_waals_channel ? " v" : " " ),
96
96
van_der_waals_channel,
97
97
kargs... )
98
98
end
99
99
function drop_preprocessing (x:: Chain )
100
100
if typeof (x[1 ]) <: PreprocessingLayer
101
- Chain (NoOpLayer (), map (i -> x[i], 2 : length (x))... , disable_optimizations = true )
101
+ Chain (NoOpLayer (), map (i -> x[i], 2 : length (x))... , disable_optimizations= true )
102
102
else
103
103
x
104
104
end
@@ -115,6 +115,11 @@ struct SerializedModel
115
115
model:: Partial
116
116
weights:: NamedTuple
117
117
end
118
+ production_instantiate ((; model, weights):: SerializedModel ) = Lux. StatefulLuxLayer (
119
+ model (),
120
+ weights,
121
+ Lux. initialparameters (MersenneTwister (42 ), model. model ()) |> Lux. testmode)
122
+
118
123
function get_cutoff_radius (x:: Lux.AbstractExplicitLayer )
119
124
get_preprocessing (x). fun. kargs[:cutoff_radius ]
120
125
end
0 commit comments