Skip to content

Commit bb309df

Browse files
committed
fix: weird type error
1 parent b754c50 commit bb309df

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

src/models.jl

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@ using Lux
33
function select_and_preprocess((point, atoms); cutoff_radius)
44
select_and_preprocess(point, atoms; cutoff_radius)
55
end
6-
function select_and_preprocess(point::Batch, atoms::AnnotedKDTree{Sphere{T}}; cutoff_radius) where T
6+
function select_and_preprocess(
7+
point::Batch, atoms::AnnotedKDTree{Sphere{T}}; cutoff_radius) where {T}
78
neighboord = Folds.map(point.field) do point
8-
select_neighboord(point, atoms; cutoff_radius)::StructVector{Sphere{T}}
9-
end |> Batch{Vector{<:StructVector{Sphere{T}}}}
9+
select_neighboord(point,
10+
atoms;
11+
cutoff_radius)::StructVector{
12+
Sphere{T}, @NamedTuple{center::Vector{Point3{T}}, r::Vector{T}}, Int64}
13+
end |> Batch{Vector{StructVector{
14+
Sphere{T}, @NamedTuple{center::Vector{Point3{T}}, r::Vector{T}}, Int64}}}
1015
preprocessing((point, neighboord))
1116
end
1217

@@ -19,9 +24,9 @@ function evaluate_if_atoms_in_neighboord(layer, arg::AbstractArray, ps, st; zero
1924
end
2025

2126
function general_angular_dense(main_chain, secondary_chain; name::String,
22-
van_der_waals_channel=false, on_gpu=true, cutoff_radius::Float32=3.0f0)
27+
van_der_waals_channel = false, on_gpu = true, cutoff_radius::Float32 = 3.0f0)
2328
main_chain = DeepSet(Chain(
24-
symetrise(; cutoff_radius, device=on_gpu ? gpu_device() : identity),
29+
symetrise(; cutoff_radius, device = on_gpu ? gpu_device() : identity),
2530
main_chain
2631
))
2732
function add_van_der_waals_channel(main_chain)
@@ -41,7 +46,7 @@ end
4146
`tiny_angular_dense` is a function that generate a lux model.
4247
4348
"""
44-
function tiny_angular_dense(; van_der_waals_channel=false, kargs...)
49+
function tiny_angular_dense(; van_der_waals_channel = false, kargs...)
4550
general_angular_dense(
4651
Parallel(.*,
4752
Chain(Dense(6 => 7, elu),
@@ -52,12 +57,12 @@ function tiny_angular_dense(; van_der_waals_channel=false, kargs...)
5257
BatchNorm(4 + van_der_waals_channel),
5358
Dense(4 + van_der_waals_channel => 6, elu),
5459
Dense(6 => 1, sigmoid_fast));
55-
name="tiny_angular_dense_" *
56-
(van_der_waals_channel ? "v" : ""),
60+
name = "tiny_angular_dense_" *
61+
(van_der_waals_channel ? "v" : ""),
5762
van_der_waals_channel, kargs...)
5863
end
5964

60-
function light_angular_dense(; van_der_waals_channel=false, kargs...)
65+
function light_angular_dense(; van_der_waals_channel = false, kargs...)
6166
general_angular_dense(
6267
Parallel(.*,
6368
Chain(Dense(6 => 10, elu),
@@ -68,13 +73,13 @@ function light_angular_dense(; van_der_waals_channel=false, kargs...)
6873
BatchNorm(5 + van_der_waals_channel),
6974
Dense(5 + van_der_waals_channel => 10, elu),
7075
Dense(10 => 1, sigmoid_fast));
71-
name="light_angular_dense_" *
72-
(van_der_waals_channel ? "v" : ""),
76+
name = "light_angular_dense_" *
77+
(van_der_waals_channel ? "v" : ""),
7378
van_der_waals_channel, kargs...)
7479
end
7580

7681
function medium_angular_dense(;
77-
van_der_waals_channel=false, kargs...)
82+
van_der_waals_channel = false, kargs...)
7883
general_angular_dense(
7984
Parallel(.*,
8085
Chain(Dense(6 => 15, elu),
@@ -83,11 +88,11 @@ function medium_angular_dense(;
8388
),
8489
Chain(
8590
BatchNorm(10 + van_der_waals_channel),
86-
Dense(10 + van_der_waals_channel => 5; use_bias=false),
91+
Dense(10 + van_der_waals_channel => 5; use_bias = false),
8792
Dense(5 => 10, elu),
8893
Dense(10 => 1, sigmoid_fast));
89-
name="medium_angular_dense_" *
90-
(van_der_waals_channel ? "v" : ""),
94+
name = "medium_angular_dense_" *
95+
(van_der_waals_channel ? "v" : ""),
9196
van_der_waals_channel,
9297
kargs...)
9398
end
@@ -113,3 +118,7 @@ function get_cutoff_radius(x::Lux.AbstractExplicitLayer)
113118
get_preprocessing(x).fun.kargs[:cutoff_radius]
114119
end
115120
get_cutoff_radius(x::Lux.StatefulLuxLayer) = get_cutoff_radius(x.model)
121+
function get_cutoff_radius(x::Lux.AbstractExplicitLayer)
122+
get_preprocessing(x).fun.kargs[:cutoff_radius]
123+
end
124+
get_cutoff_radius(x::Lux.StatefulLuxLayer) = get_cutoff_radius(x.model)

0 commit comments

Comments
 (0)