Skip to content

Commit 28ddb41

Browse files
committed
fix: fixed error with evaluation model only giving the first value
1 parent 766e1db commit 28ddb41

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

examples/dummy_example.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
int main(int argc,char *argv[]) {
55
init_julia(argc, argv);
66
return 0;
7-
load_model("/home/tristan/datasets/models/"
8-
"angular_dense_2Apf_epoch_10_16451353003083222301");
7+
load_model("tiny_angular_dense_3.0A_smooth_14_categorical_2024-08-02_epoch_70_18127875713564776610");
98
sphere data[2]= {{0.,0.,0.,1.},{1.,0.,0.,1.}};
109
load_atoms(data,2);
1110
point x[2] = {{0.,0.,1.},{1.,0.,0.}};

src/c_interface.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,22 +117,23 @@ function evaluate_model(
117117
if distance(x, atoms.tree) >= cutoff_radius
118118
default_value
119119
else
120-
model((x, atoms)) |> cpu_device() |> first
120+
model((Batch(x), atoms)) |> cpu_device() |> first |> only
121121
end
122122
end
123-
124123
function evaluate_model(
125124
model::Lux.StatefulLuxLayer, x::Batch{Vector{Point3f}}, atoms::AnnotedKDTree;
126-
cutoff_radius, default_value = 0.0f0)
125+
cutoff_radius, default_value = 0.0f0)::Vector{Float32}
127126
is_close = map(x.field) do x
128-
distance(x, atoms.tree) <= cutoff_radius
127+
distance(x, atoms.tree) < cutoff_radius
129128
end
130129
close_points = x.field[is_close] |> Batch
131130
if length(close_points.field) > 0
132-
close_values = model((close_points, atoms)) |> cpu_device() |> first
133-
ifelse.(is_close, close_values, default_value)
131+
close_values = model((close_points, atoms)) |> cpu_device()
132+
res = fill(default_value,size(x.field)...)
133+
res[is_close] = close_values
134+
res
134135
else
135-
zeros(x.field)
136+
fill(default_value,size(x.field)...)
136137
end
137138
end
138139
Base.@ccallable function eval_model(points::Ptr{CPoint},length::Cint)::Float32

0 commit comments

Comments
 (0)