@@ -3,10 +3,15 @@ using Lux
3
3
function select_and_preprocess ((point, atoms); cutoff_radius)
4
4
select_and_preprocess (point, atoms; cutoff_radius)
5
5
end
6
- function select_and_preprocess (point:: Batch , atoms:: AnnotedKDTree{Sphere{T}} ; cutoff_radius) where T
6
+ function select_and_preprocess (
7
+ point:: Batch , atoms:: AnnotedKDTree{Sphere{T}} ; cutoff_radius) where {T}
7
8
neighboord = Folds. map (point. field) do point
8
- select_neighboord (point, atoms; cutoff_radius):: StructVector{Sphere{T}}
9
- end |> Batch{Vector{<: StructVector{Sphere{T}} }}
9
+ select_neighboord (point,
10
+ atoms;
11
+ cutoff_radius):: StructVector {
12
+ Sphere{T}, @NamedTuple {center:: Vector{Point3{T}} , r:: Vector{T} }, Int64}
13
+ end |> Batch{Vector{StructVector{
14
+ Sphere{T}, @NamedTuple {center:: Vector{Point3{T}} , r:: Vector{T} }, Int64}}}
10
15
preprocessing ((point, neighboord))
11
16
end
12
17
@@ -19,9 +24,9 @@ function evaluate_if_atoms_in_neighboord(layer, arg::AbstractArray, ps, st; zero
19
24
end
20
25
21
26
function general_angular_dense (main_chain, secondary_chain; name:: String ,
22
- van_der_waals_channel= false , on_gpu= true , cutoff_radius:: Float32 = 3.0f0 )
27
+ van_der_waals_channel = false , on_gpu = true , cutoff_radius:: Float32 = 3.0f0 )
23
28
main_chain = DeepSet (Chain (
24
- symetrise (; cutoff_radius, device= on_gpu ? gpu_device () : identity),
29
+ symetrise (; cutoff_radius, device = on_gpu ? gpu_device () : identity),
25
30
main_chain
26
31
))
27
32
function add_van_der_waals_channel (main_chain)
41
46
`tiny_angular_dense` is a function that generate a lux model.
42
47
43
48
"""
44
- function tiny_angular_dense (; van_der_waals_channel= false , kargs... )
49
+ function tiny_angular_dense (; van_der_waals_channel = false , kargs... )
45
50
general_angular_dense (
46
51
Parallel (.* ,
47
52
Chain (Dense (6 => 7 , elu),
@@ -52,12 +57,12 @@ function tiny_angular_dense(; van_der_waals_channel=false, kargs...)
52
57
BatchNorm (4 + van_der_waals_channel),
53
58
Dense (4 + van_der_waals_channel => 6 , elu),
54
59
Dense (6 => 1 , sigmoid_fast));
55
- name= " tiny_angular_dense_" *
56
- (van_der_waals_channel ? " v" : " " ),
60
+ name = " tiny_angular_dense_" *
61
+ (van_der_waals_channel ? " v" : " " ),
57
62
van_der_waals_channel, kargs... )
58
63
end
59
64
60
- function light_angular_dense (; van_der_waals_channel= false , kargs... )
65
+ function light_angular_dense (; van_der_waals_channel = false , kargs... )
61
66
general_angular_dense (
62
67
Parallel (.* ,
63
68
Chain (Dense (6 => 10 , elu),
@@ -68,13 +73,13 @@ function light_angular_dense(; van_der_waals_channel=false, kargs...)
68
73
BatchNorm (5 + van_der_waals_channel),
69
74
Dense (5 + van_der_waals_channel => 10 , elu),
70
75
Dense (10 => 1 , sigmoid_fast));
71
- name= " light_angular_dense_" *
72
- (van_der_waals_channel ? " v" : " " ),
76
+ name = " light_angular_dense_" *
77
+ (van_der_waals_channel ? " v" : " " ),
73
78
van_der_waals_channel, kargs... )
74
79
end
75
80
76
81
function medium_angular_dense (;
77
- van_der_waals_channel= false , kargs... )
82
+ van_der_waals_channel = false , kargs... )
78
83
general_angular_dense (
79
84
Parallel (.* ,
80
85
Chain (Dense (6 => 15 , elu),
@@ -83,11 +88,11 @@ function medium_angular_dense(;
83
88
),
84
89
Chain (
85
90
BatchNorm (10 + van_der_waals_channel),
86
- Dense (10 + van_der_waals_channel => 5 ; use_bias= false ),
91
+ Dense (10 + van_der_waals_channel => 5 ; use_bias = false ),
87
92
Dense (5 => 10 , elu),
88
93
Dense (10 => 1 , sigmoid_fast));
89
- name= " medium_angular_dense_" *
90
- (van_der_waals_channel ? " v" : " " ),
94
+ name = " medium_angular_dense_" *
95
+ (van_der_waals_channel ? " v" : " " ),
91
96
van_der_waals_channel,
92
97
kargs... )
93
98
end
@@ -113,3 +118,7 @@ function get_cutoff_radius(x::Lux.AbstractExplicitLayer)
113
118
get_preprocessing (x). fun. kargs[:cutoff_radius ]
114
119
end
115
120
get_cutoff_radius (x:: Lux.StatefulLuxLayer ) = get_cutoff_radius (x. model)
121
+ function get_cutoff_radius (x:: Lux.AbstractExplicitLayer )
122
+ get_preprocessing (x). fun. kargs[:cutoff_radius ]
123
+ end
124
+ get_cutoff_radius (x:: Lux.StatefulLuxLayer ) = get_cutoff_radius (x. model)
0 commit comments