@@ -5,20 +5,21 @@ use crate::{
55 } ,
66 reconstruction:: { SurfaceReconstructionF32 , SurfaceReconstructionF64 } ,
77} ;
8- use numpy:: { Element , PyArray1 , PyArray2 , PyReadonlyArray1 , PyReadonlyArray2 } ;
8+ use numpy:: { Element , PyArray1 , PyArray2 , PyArrayMethods , PyReadonlyArray1 , PyReadonlyArray2 } ;
99use pyo3:: {
1010 prelude:: * ,
1111 types:: { PyDict , PyString } ,
1212} ;
1313use splashsurf_lib:: {
1414 Aabb3d , GridDecompositionParameters , Index , Real , SpatialDecomposition ,
15- mesh:: { OwnedAttributeData , OwnedMeshAttribute } ,
15+ mesh:: { AttributeData , MeshAttribute } ,
1616 nalgebra:: Vector3 ,
1717} ;
18+ use std:: borrow:: Cow ;
1819
19- fn reconstruction_pipeline_generic < I : Index , R : Real > (
20- particle_positions : & [ Vector3 < R > ] ,
21- attributes : Vec < OwnedMeshAttribute < R > > ,
20+ fn reconstruction_pipeline_generic < ' py , I : Index , R : Real + Element > (
21+ particles : & Bound < ' py , PyArray2 < R > > ,
22+ attributes_to_interpolate : Bound < ' py , PyDict > ,
2223 particle_radius : R ,
2324 rest_density : R ,
2425 smoothing_length : R ,
@@ -55,6 +56,60 @@ fn reconstruction_pipeline_generic<I: Index, R: Real>(
5556 mesh_aabb_max : Option < [ f64 ; 3 ] > ,
5657 mesh_aabb_clamp_vertices : bool ,
5758) -> Result < splashsurf:: reconstruct:: ReconstructionResult < I , R > , anyhow:: Error > {
59+ let particles: PyReadonlyArray2 < R > = particles. readonly ( ) ;
60+ let particle_positions: & [ Vector3 < R > ] = bytemuck:: cast_slice ( particles. as_slice ( ) ?) ;
61+
62+ enum AttributePyView < ' a , R : Real + Element > {
63+ U64 ( PyReadonlyArray1 < ' a , u64 > ) ,
64+ Float ( PyReadonlyArray1 < ' a , R > ) ,
65+ FloatVec3 ( PyReadonlyArray2 < ' a , R > ) ,
66+ }
67+
68+ let mut attr_names = Vec :: new ( ) ;
69+ let mut attr_views = Vec :: new ( ) ;
70+
71+ // Collect readonly views of all attribute arrays
72+ for ( key, value) in attributes_to_interpolate. iter ( ) {
73+ let key_str: String = key
74+ . downcast :: < PyString > ( )
75+ . expect ( "Key wasn't a string" )
76+ . extract ( ) ?;
77+
78+ if let Ok ( value) = value. downcast :: < PyArray1 < u64 > > ( ) {
79+ attr_views. push ( AttributePyView :: U64 ( value. readonly ( ) ) ) ;
80+ attr_names. push ( key_str) ;
81+ } else if let Ok ( value) = value. downcast :: < PyArray1 < R > > ( ) {
82+ attr_views. push ( AttributePyView :: Float ( value. readonly ( ) ) ) ;
83+ attr_names. push ( key_str) ;
84+ } else if let Ok ( value) = value. downcast :: < PyArray2 < R > > ( ) {
85+ attr_views. push ( AttributePyView :: FloatVec3 ( value. readonly ( ) ) ) ;
86+ attr_names. push ( key_str) ;
87+ } else {
88+ println ! ( "Couldn't downcast attribute {} to valid type" , & key_str) ;
89+ }
90+ }
91+
92+ // Get slices from attribute views and construct borrowed MeshAttributes
93+ let attributes = attr_names
94+ . into_iter ( )
95+ . zip ( attr_views. iter ( ) )
96+ . map ( |( name, view) | -> Result < MeshAttribute < R > , anyhow:: Error > {
97+ let data = match view {
98+ AttributePyView :: U64 ( view) => {
99+ AttributeData :: ScalarU64 ( Cow :: Borrowed ( view. as_slice ( ) ?. into ( ) ) )
100+ }
101+ AttributePyView :: Float ( view) => {
102+ AttributeData :: ScalarReal ( Cow :: Borrowed ( view. as_slice ( ) ?. into ( ) ) )
103+ }
104+ AttributePyView :: FloatVec3 ( view) => {
105+ let vec3_slice: & [ Vector3 < R > ] = bytemuck:: cast_slice ( view. as_slice ( ) ?) ;
106+ AttributeData :: Vector3Real ( Cow :: Borrowed ( vec3_slice. into ( ) ) )
107+ }
108+ } ;
109+ Ok ( MeshAttribute :: new ( name, data) )
110+ } )
111+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
112+
58113 let aabb = if let ( Some ( aabb_min) , Some ( aabb_max) ) = ( aabb_min, aabb_max) {
59114 // Convert the min and max arrays to Vector3
60115 Some ( Aabb3d :: new (
@@ -132,55 +187,6 @@ fn reconstruction_pipeline_generic<I: Index, R: Real>(
132187 )
133188}
134189
135- fn attrs_conversion < R : Real + Element > (
136- attributes_to_interpolate : Bound < PyDict > ,
137- ) -> Vec < OwnedMeshAttribute < R > > {
138- let mut attrs: Vec < OwnedMeshAttribute < R > > = Vec :: new ( ) ;
139- for ( key, value) in attributes_to_interpolate. iter ( ) {
140- let key_str: String = key
141- . downcast :: < PyString > ( )
142- . expect ( "Key wasn't a string" )
143- . extract ( )
144- . unwrap ( ) ;
145-
146- if let Ok ( value) = value. downcast :: < PyArray1 < u64 > > ( ) {
147- let value: Vec < u64 > = value
148- . extract :: < PyReadonlyArray1 < u64 > > ( )
149- . unwrap ( )
150- . as_slice ( )
151- . unwrap ( )
152- . to_vec ( ) ;
153- let mesh_attr =
154- OwnedMeshAttribute :: new ( key_str, OwnedAttributeData :: ScalarU64 ( value. into ( ) ) ) ;
155- attrs. push ( mesh_attr) ;
156- } else if let Ok ( value) = value. downcast :: < PyArray1 < R > > ( ) {
157- let value: Vec < R > = value
158- . extract :: < PyReadonlyArray1 < R > > ( )
159- . unwrap ( )
160- . as_slice ( )
161- . unwrap ( )
162- . to_vec ( ) ;
163- let mesh_attr =
164- OwnedMeshAttribute :: new ( key_str, OwnedAttributeData :: ScalarReal ( value. into ( ) ) ) ;
165- attrs. push ( mesh_attr) ;
166- } else if let Ok ( value) = value. downcast :: < PyArray2 < R > > ( ) {
167- let value: PyReadonlyArray2 < R > = value. extract ( ) . unwrap ( ) ;
168-
169- let value_slice = value. as_slice ( ) . unwrap ( ) ;
170- let value_slice: & [ Vector3 < R > ] = bytemuck:: cast_slice ( value_slice) ;
171-
172- let mesh_attr = OwnedMeshAttribute :: new (
173- key_str,
174- OwnedAttributeData :: Vector3Real ( value_slice. to_vec ( ) . into ( ) ) ,
175- ) ;
176- attrs. push ( mesh_attr) ;
177- } else {
178- println ! ( "Couldnt downcast attribute {} to valid type" , & key_str) ;
179- }
180- }
181- attrs
182- }
183-
184190#[ pyfunction]
185191#[ pyo3( name = "reconstruction_pipeline_f32" ) ]
186192#[ pyo3( signature = ( particles, * , attributes_to_interpolate, particle_radius, rest_density,
@@ -237,20 +243,13 @@ pub fn reconstruction_pipeline_py_f32<'py>(
237243 Option < MixedTriQuadMeshWithDataF32 > ,
238244 Option < SurfaceReconstructionF32 > ,
239245) > {
240- let particles: PyReadonlyArray2 < f32 > = particles. extract ( ) ?;
241-
242- let particle_positions = particles. as_slice ( ) ?;
243- let particle_positions: & [ Vector3 < f32 > ] = bytemuck:: cast_slice ( particle_positions) ;
244-
245- let attrs = attrs_conversion ( attributes_to_interpolate) ;
246-
247246 let splashsurf:: reconstruct:: ReconstructionResult {
248247 tri_mesh,
249248 tri_quad_mesh,
250249 raw_reconstruction : reconstruction,
251250 } = reconstruction_pipeline_generic :: < i64 , f32 > (
252- particle_positions ,
253- attrs ,
251+ particles ,
252+ attributes_to_interpolate ,
254253 particle_radius,
255254 rest_density,
256255 smoothing_length,
@@ -286,8 +285,7 @@ pub fn reconstruction_pipeline_py_f32<'py>(
286285 mesh_aabb_min,
287286 mesh_aabb_max,
288287 mesh_aabb_clamp_vertices,
289- )
290- . unwrap ( ) ;
288+ ) ?;
291289
292290 Ok ( (
293291 tri_mesh. map ( TriMeshWithDataF32 :: new) ,
@@ -352,20 +350,13 @@ pub fn reconstruction_pipeline_py_f64<'py>(
352350 Option < MixedTriQuadMeshWithDataF64 > ,
353351 Option < SurfaceReconstructionF64 > ,
354352) > {
355- let particles: PyReadonlyArray2 < f64 > = particles. extract ( ) ?;
356-
357- let particle_positions = particles. as_slice ( ) ?;
358- let particle_positions: & [ Vector3 < f64 > ] = bytemuck:: cast_slice ( particle_positions) ;
359-
360- let attrs = attrs_conversion ( attributes_to_interpolate) ;
361-
362353 let splashsurf:: reconstruct:: ReconstructionResult {
363354 tri_mesh,
364355 tri_quad_mesh,
365356 raw_reconstruction : reconstruction,
366357 } = reconstruction_pipeline_generic :: < i64 , f64 > (
367- particle_positions ,
368- attrs ,
358+ particles ,
359+ attributes_to_interpolate ,
369360 particle_radius,
370361 rest_density,
371362 smoothing_length,
@@ -401,8 +392,7 @@ pub fn reconstruction_pipeline_py_f64<'py>(
401392 mesh_aabb_min,
402393 mesh_aabb_max,
403394 mesh_aabb_clamp_vertices,
404- )
405- . unwrap ( ) ;
395+ ) ?;
406396
407397 Ok ( (
408398 tri_mesh. map ( TriMeshWithDataF64 :: new) ,
0 commit comments