Skip to content

Commit 0a106f6

Browse files
committed
fix: fix model
1 parent 4f55bd5 commit 0a106f6

File tree

4 files changed

+108
-28
lines changed

4 files changed

+108
-28
lines changed

Manifest.toml

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
6666
version = "1.1.1"
6767

6868
[[deps.ArrayInterface]]
69-
deps = ["Adapt", "LinearAlgebra", "SparseArrays", "SuiteSparse"]
70-
git-tree-sha1 = "5c9b74c973181571deb6442d41e5c902e6b9f38e"
69+
deps = ["Adapt", "LinearAlgebra"]
70+
git-tree-sha1 = "8c5b39db37c1d0340bf3b14895fba160c2d6cbb5"
7171
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
72-
version = "7.12.0"
72+
version = "7.14.0"
7373

7474
[deps.ArrayInterface.extensions]
7575
ArrayInterfaceBandedMatricesExt = "BandedMatrices"
@@ -79,7 +79,8 @@ version = "7.12.0"
7979
ArrayInterfaceChainRulesExt = "ChainRules"
8080
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
8181
ArrayInterfaceReverseDiffExt = "ReverseDiff"
82-
ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
82+
ArrayInterfaceSparseArraysExt = "SparseArrays"
83+
ArrayInterfaceStaticArraysExt = "StaticArrays"
8384
ArrayInterfaceTrackerExt = "Tracker"
8485

8586
[deps.ArrayInterface.weakdeps]
@@ -90,7 +91,8 @@ version = "7.12.0"
9091
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
9192
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
9293
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
93-
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
94+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
95+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
9496
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
9597

9698
[[deps.Artifacts]]
@@ -146,9 +148,9 @@ version = "0.1.5"
146148

147149
[[deps.BioStructures]]
148150
deps = ["BioGenerics", "BioSymbols", "CodecZlib", "Downloads", "Format", "LinearAlgebra", "PrecompileTools", "RecipesBase", "Statistics"]
149-
git-tree-sha1 = "5bc9736b06063d04bc664a0e2445cbc3412012d0"
151+
git-tree-sha1 = "4326bae6741b7f929b67ea8a1987782001c8b1dd"
150152
uuid = "de9282ab-8554-53be-b2d6-f6c222edabfc"
151-
version = "4.0.0"
153+
version = "4.1.0"
152154

153155
[deps.BioStructures.extensions]
154156
BioStructuresBioAlignmentsExt = ["BioSequences", "BioAlignments"]
@@ -734,10 +736,10 @@ version = "0.11.1"
734736
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
735737

736738
[[deps.Lux]]
737-
deps = ["ADTypes", "Adapt", "ArgCheck", "ArrayInterface", "ChainRulesCore", "Compat", "ConcreteStructs", "ConstructionBase", "EnzymeCore", "FastClosures", "ForwardDiff", "Functors", "GPUArraysCore", "LinearAlgebra", "LossFunctions", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "Optimisers", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "UnrolledUtilities", "WeightInitializers"]
738-
git-tree-sha1 = "cad82f8dc3239ea5b9eda9e918f11da98616ec8e"
739+
deps = ["ADTypes", "Adapt", "ArgCheck", "ArrayInterface", "ChainRulesCore", "Compat", "ConcreteStructs", "ConstructionBase", "DispatchDoctor", "EnzymeCore", "FastClosures", "ForwardDiff", "Functors", "GPUArraysCore", "LinearAlgebra", "LossFunctions", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "NNlib", "Optimisers", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "UnrolledUtilities", "WeightInitializers"]
740+
git-tree-sha1 = "9b362dc50fc1ab1bb7222cfbabe104358d85b3f9"
739741
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
740-
version = "0.5.61"
742+
version = "0.5.62"
741743

742744
[deps.Lux.extensions]
743745
LuxComponentArraysExt = "ComponentArrays"
@@ -769,14 +771,15 @@ version = "0.5.61"
769771

770772
[[deps.LuxCore]]
771773
deps = ["Compat", "DispatchDoctor", "Functors", "Random", "Setfield"]
772-
git-tree-sha1 = "e25dd111094e3eefdcac5f094e7a8221b98ee44c"
774+
git-tree-sha1 = "ddb556f073f7acbafa1589075a87ef4f3d7e02d2"
773775
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
774-
version = "0.1.21"
775-
weakdeps = ["ChainRulesCore", "EnzymeCore"]
776+
version = "0.1.22"
777+
weakdeps = ["ChainRulesCore", "EnzymeCore", "MLDataDevices"]
776778

777779
[deps.LuxCore.extensions]
778780
LuxCoreChainRulesCoreExt = "ChainRulesCore"
779781
LuxCoreEnzymeCoreExt = "EnzymeCore"
782+
LuxCoreMLDataDevicesExt = "MLDataDevices"
780783

781784
[[deps.LuxDeviceUtils]]
782785
deps = ["Adapt", "ChainRulesCore", "Functors", "LuxCore", "Preferences", "Random", "UnrolledUtilities"]
@@ -813,10 +816,10 @@ version = "0.1.26"
813816
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
814817

815818
[[deps.LuxLib]]
816-
deps = ["ArrayInterface", "ChainRulesCore", "DispatchDoctor", "EnzymeCore", "FastClosures", "ForwardDiff", "KernelAbstractions", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "Markdown", "NNlib", "Random", "Reexport", "SLEEFPirates", "StaticArraysCore", "Statistics", "UnrolledUtilities"]
817-
git-tree-sha1 = "28c3e6ced4d45bb112e1322668f59e23523b784c"
819+
deps = ["ArrayInterface", "ChainRulesCore", "DispatchDoctor", "EnzymeCore", "FastClosures", "ForwardDiff", "KernelAbstractions", "LinearAlgebra", "LuxCore", "MLDataDevices", "Markdown", "NNlib", "Random", "Reexport", "SLEEFPirates", "StaticArraysCore", "Statistics", "UnrolledUtilities"]
820+
git-tree-sha1 = "7f5984033c9e41840111ec26b75bfb9c292db1d8"
818821
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
819-
version = "0.3.34"
822+
version = "0.3.37"
820823

821824
[deps.LuxLib.extensions]
822825
LuxLibCUDAExt = "CUDA"
@@ -832,6 +835,40 @@ version = "0.3.34"
832835
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
833836
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
834837

838+
[[deps.MLDataDevices]]
839+
deps = ["Adapt", "ChainRulesCore", "Functors", "Preferences", "Random", "UnrolledUtilities"]
840+
git-tree-sha1 = "7c8a26a11195c49062d11f3f0160335b10f07303"
841+
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
842+
version = "1.0.0"
843+
844+
[deps.MLDataDevices.extensions]
845+
MLDataDevicesAMDGPUExt = "AMDGPU"
846+
MLDataDevicesCUDAExt = "CUDA"
847+
MLDataDevicesFillArraysExt = "FillArrays"
848+
MLDataDevicesGPUArraysExt = "GPUArrays"
849+
MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
850+
MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
851+
MLDataDevicesReverseDiffExt = "ReverseDiff"
852+
MLDataDevicesSparseArraysExt = "SparseArrays"
853+
MLDataDevicesTrackerExt = "Tracker"
854+
MLDataDevicesZygoteExt = "Zygote"
855+
MLDataDevicescuDNNExt = ["CUDA", "cuDNN"]
856+
MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
857+
858+
[deps.MLDataDevices.weakdeps]
859+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
860+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
861+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
862+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
863+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
864+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
865+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
866+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
867+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
868+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
869+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
870+
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
871+
835872
[[deps.MacroTools]]
836873
deps = ["Markdown", "Random"]
837874
git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df"
@@ -1321,9 +1358,9 @@ weakdeps = ["Requires", "StaticArraysCore"]
13211358

13221359
[[deps.WeightInitializers]]
13231360
deps = ["ArgCheck", "ChainRulesCore", "ConcreteStructs", "GPUArraysCore", "LinearAlgebra", "Random", "SpecialFunctions", "Statistics"]
1324-
git-tree-sha1 = "2bf1485c270ac087ad707df929ce00de19a885ee"
1361+
git-tree-sha1 = "58d74e2f95c825935b8307360e66b73359db47fd"
13251362
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
1326-
version = "0.1.10"
1363+
version = "1.0.0"
13271364

13281365
[deps.WeightInitializers.extensions]
13291366
WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"]

src/MLNanoShaperRunner.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ typedef struct {
33
float y;
44
float z;
55
float r;
6-
} Sphere;
6+
} sphere;
77

8-
// void init_julia(int argc, char *argv[]);
8+
typedef struct {
9+
float x;
10+
float y;
11+
float z;
12+
} point;
913
// void shutdown_julia(int retcode);
1014

1115
/*
@@ -31,10 +35,10 @@ length of the array
3135
- 1: data could not be read
3236
- 2: unknow error
3337
*/
34-
int load_atoms(Sphere *start, int length);
38+
int load_atoms(sphere *start, int length);
3539
/*
3640
eval_model(x::Float32,y::Float32,z::Float32)::Float32
3741
3842
evaluate the model at coordinates `x` `y` `z`.
3943
*/
40-
float eval_model(float x, float y, float z);
44+
float eval_model(point *start,int length);

src/c_interface.jl

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ mutable struct CSphere
2828
z::Float32
2929
r::Float32
3030
end
31+
mutable struct CPoint
32+
x::Float32
33+
y::Float32
34+
z::Float32
35+
end
3136
CSphere((; center, r)::Sphere) = CSphere(center..., r)
3237

3338
"""
@@ -101,11 +106,40 @@ Base.@ccallable function load_atoms(start::Ptr{CSphere}, length::Cint)::Cint
101106
end
102107
end
103108

104-
Base.@ccallable function eval_model(x::Float32, y::Float32, z::Float32)::Float32
105-
point = Point3f(x, y, z)
106-
if distance(point, global_state.atoms.tree) >= global_state.cutoff_radius
107-
0.0f0
109+
"""
110+
evaluate_model(
111+
model::Lux.StatefulLuxLayer, x::Point3f, atoms::AnnotedKDTree; cutoff_radius, default_value = -0.0f0)
112+
113+
evaluate the model on a single point.
114+
This function handle the logic in case the point is too far from the atoms. In this case default_value is returned and the model is not run.
115+
"""
116+
function evaluate_model(
117+
model::Lux.StatefulLuxLayer, x::Point3f, atoms::AnnotedKDTree; cutoff_radius, default_value = -0.0f0)
118+
if distance(x, atoms.tree) >= cutoff_radius
119+
default_value
120+
else
121+
model((x, atoms)) |> cpu_device() |> first
122+
end
123+
end
124+
125+
function evaluate_model(
126+
model::Lux.StatefulLuxLayer, x::Batch{Vector{Point3f}}, atoms::AnnotedKDTree;
127+
cutoff_radius, default_value = 0.0f0)
128+
is_close = map(x.field) do x
129+
distance(x, atoms.tree) <= cutoff_radius
130+
end
131+
close_points = x.field[is_close] |> Batch
132+
if length(close_points.field) > 0
133+
close_values = model((close_points, atoms)) |> cpu_device() |> first
134+
ifelse.(is_close, close_values, default_value)
108135
else
109-
only(model((Point3f(x), global_state.atoms))) - 0.5f0
136+
zeros(x.field)
110137
end
111138
end
139+
Base.@ccallable function eval_model(points::Ptr{CPoints},length::Cint)::Float32
140+
points = map(unsafe_wrap(Array, points, length)) do p
141+
Point3f(p.x,p.y,p.z)
142+
end |> Batch
143+
evaluate_model(global_state.model,points,global_state.atoms;cutoff_radius = global_state.cutoff_radius)
144+
145+
end

src/distance_tree.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ end
2424

2525
distance(x::AbstractVector, y::KDTree)::Number = nn(y, x) |> last
2626

27-
function signed_distance(p::Point3, mesh::RegionMesh)::Number
27+
"""
28+
signed_distance(p::Point3, mesh::RegionMesh)::Number
29+
30+
returns the signed distance between point p and the mesh
31+
"""
32+
function signed_distance(p::Point3{T}, mesh::RegionMesh)::T where T<:Number
2833
id_point, dist = nn(mesh.tree, p)
2934
x, y, z = mesh.triangles[OffsetInteger{-1, UInt32}(id_point)]
3035
# @info "triangle" x y z

0 commit comments

Comments
 (0)