Skip to content

Commit 7ec0c3c

Browse files
committed
Support reconstruction with non-owned attributes
1 parent 7e09c81 commit 7ec0c3c

File tree

2 files changed

+72
-82
lines changed

2 files changed

+72
-82
lines changed

pysplashsurf/src/pipeline.rs

Lines changed: 66 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,21 @@ use crate::{
55
},
66
reconstruction::{SurfaceReconstructionF32, SurfaceReconstructionF64},
77
};
8-
use numpy::{Element, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
8+
use numpy::{Element, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
99
use pyo3::{
1010
prelude::*,
1111
types::{PyDict, PyString},
1212
};
1313
use splashsurf_lib::{
1414
Aabb3d, GridDecompositionParameters, Index, Real, SpatialDecomposition,
15-
mesh::{OwnedAttributeData, OwnedMeshAttribute},
15+
mesh::{AttributeData, MeshAttribute},
1616
nalgebra::Vector3,
1717
};
18+
use std::borrow::Cow;
1819

19-
fn reconstruction_pipeline_generic<I: Index, R: Real>(
20-
particle_positions: &[Vector3<R>],
21-
attributes: Vec<OwnedMeshAttribute<R>>,
20+
fn reconstruction_pipeline_generic<'py, I: Index, R: Real + Element>(
21+
particles: &Bound<'py, PyArray2<R>>,
22+
attributes_to_interpolate: Bound<'py, PyDict>,
2223
particle_radius: R,
2324
rest_density: R,
2425
smoothing_length: R,
@@ -55,6 +56,60 @@ fn reconstruction_pipeline_generic<I: Index, R: Real>(
5556
mesh_aabb_max: Option<[f64; 3]>,
5657
mesh_aabb_clamp_vertices: bool,
5758
) -> Result<splashsurf::reconstruct::ReconstructionResult<I, R>, anyhow::Error> {
59+
let particles: PyReadonlyArray2<R> = particles.readonly();
60+
let particle_positions: &[Vector3<R>] = bytemuck::cast_slice(particles.as_slice()?);
61+
62+
enum AttributePyView<'a, R: Real + Element> {
63+
U64(PyReadonlyArray1<'a, u64>),
64+
Float(PyReadonlyArray1<'a, R>),
65+
FloatVec3(PyReadonlyArray2<'a, R>),
66+
}
67+
68+
let mut attr_names = Vec::new();
69+
let mut attr_views = Vec::new();
70+
71+
// Collect readonly views of all attribute arrays
72+
for (key, value) in attributes_to_interpolate.iter() {
73+
let key_str: String = key
74+
.downcast::<PyString>()
75+
.expect("Key wasn't a string")
76+
.extract()?;
77+
78+
if let Ok(value) = value.downcast::<PyArray1<u64>>() {
79+
attr_views.push(AttributePyView::U64(value.readonly()));
80+
attr_names.push(key_str);
81+
} else if let Ok(value) = value.downcast::<PyArray1<R>>() {
82+
attr_views.push(AttributePyView::Float(value.readonly()));
83+
attr_names.push(key_str);
84+
} else if let Ok(value) = value.downcast::<PyArray2<R>>() {
85+
attr_views.push(AttributePyView::FloatVec3(value.readonly()));
86+
attr_names.push(key_str);
87+
} else {
88+
println!("Couldn't downcast attribute {} to valid type", &key_str);
89+
}
90+
}
91+
92+
// Get slices from attribute views and construct borrowed MeshAttributes
93+
let attributes = attr_names
94+
.into_iter()
95+
.zip(attr_views.iter())
96+
.map(|(name, view)| -> Result<MeshAttribute<R>, anyhow::Error> {
97+
let data = match view {
98+
AttributePyView::U64(view) => {
99+
AttributeData::ScalarU64(Cow::Borrowed(view.as_slice()?.into()))
100+
}
101+
AttributePyView::Float(view) => {
102+
AttributeData::ScalarReal(Cow::Borrowed(view.as_slice()?.into()))
103+
}
104+
AttributePyView::FloatVec3(view) => {
105+
let vec3_slice: &[Vector3<R>] = bytemuck::cast_slice(view.as_slice()?);
106+
AttributeData::Vector3Real(Cow::Borrowed(vec3_slice.into()))
107+
}
108+
};
109+
Ok(MeshAttribute::new(name, data))
110+
})
111+
.collect::<Result<Vec<_>, _>>()?;
112+
58113
let aabb = if let (Some(aabb_min), Some(aabb_max)) = (aabb_min, aabb_max) {
59114
// Convert the min and max arrays to Vector3
60115
Some(Aabb3d::new(
@@ -132,55 +187,6 @@ fn reconstruction_pipeline_generic<I: Index, R: Real>(
132187
)
133188
}
134189

135-
fn attrs_conversion<R: Real + Element>(
136-
attributes_to_interpolate: Bound<PyDict>,
137-
) -> Vec<OwnedMeshAttribute<R>> {
138-
let mut attrs: Vec<OwnedMeshAttribute<R>> = Vec::new();
139-
for (key, value) in attributes_to_interpolate.iter() {
140-
let key_str: String = key
141-
.downcast::<PyString>()
142-
.expect("Key wasn't a string")
143-
.extract()
144-
.unwrap();
145-
146-
if let Ok(value) = value.downcast::<PyArray1<u64>>() {
147-
let value: Vec<u64> = value
148-
.extract::<PyReadonlyArray1<u64>>()
149-
.unwrap()
150-
.as_slice()
151-
.unwrap()
152-
.to_vec();
153-
let mesh_attr =
154-
OwnedMeshAttribute::new(key_str, OwnedAttributeData::ScalarU64(value.into()));
155-
attrs.push(mesh_attr);
156-
} else if let Ok(value) = value.downcast::<PyArray1<R>>() {
157-
let value: Vec<R> = value
158-
.extract::<PyReadonlyArray1<R>>()
159-
.unwrap()
160-
.as_slice()
161-
.unwrap()
162-
.to_vec();
163-
let mesh_attr =
164-
OwnedMeshAttribute::new(key_str, OwnedAttributeData::ScalarReal(value.into()));
165-
attrs.push(mesh_attr);
166-
} else if let Ok(value) = value.downcast::<PyArray2<R>>() {
167-
let value: PyReadonlyArray2<R> = value.extract().unwrap();
168-
169-
let value_slice = value.as_slice().unwrap();
170-
let value_slice: &[Vector3<R>] = bytemuck::cast_slice(value_slice);
171-
172-
let mesh_attr = OwnedMeshAttribute::new(
173-
key_str,
174-
OwnedAttributeData::Vector3Real(value_slice.to_vec().into()),
175-
);
176-
attrs.push(mesh_attr);
177-
} else {
178-
println!("Couldnt downcast attribute {} to valid type", &key_str);
179-
}
180-
}
181-
attrs
182-
}
183-
184190
#[pyfunction]
185191
#[pyo3(name = "reconstruction_pipeline_f32")]
186192
#[pyo3(signature = (particles, *, attributes_to_interpolate, particle_radius, rest_density,
@@ -237,20 +243,13 @@ pub fn reconstruction_pipeline_py_f32<'py>(
237243
Option<MixedTriQuadMeshWithDataF32>,
238244
Option<SurfaceReconstructionF32>,
239245
)> {
240-
let particles: PyReadonlyArray2<f32> = particles.extract()?;
241-
242-
let particle_positions = particles.as_slice()?;
243-
let particle_positions: &[Vector3<f32>] = bytemuck::cast_slice(particle_positions);
244-
245-
let attrs = attrs_conversion(attributes_to_interpolate);
246-
247246
let splashsurf::reconstruct::ReconstructionResult {
248247
tri_mesh,
249248
tri_quad_mesh,
250249
raw_reconstruction: reconstruction,
251250
} = reconstruction_pipeline_generic::<i64, f32>(
252-
particle_positions,
253-
attrs,
251+
particles,
252+
attributes_to_interpolate,
254253
particle_radius,
255254
rest_density,
256255
smoothing_length,
@@ -286,8 +285,7 @@ pub fn reconstruction_pipeline_py_f32<'py>(
286285
mesh_aabb_min,
287286
mesh_aabb_max,
288287
mesh_aabb_clamp_vertices,
289-
)
290-
.unwrap();
288+
)?;
291289

292290
Ok((
293291
tri_mesh.map(TriMeshWithDataF32::new),
@@ -352,20 +350,13 @@ pub fn reconstruction_pipeline_py_f64<'py>(
352350
Option<MixedTriQuadMeshWithDataF64>,
353351
Option<SurfaceReconstructionF64>,
354352
)> {
355-
let particles: PyReadonlyArray2<f64> = particles.extract()?;
356-
357-
let particle_positions = particles.as_slice()?;
358-
let particle_positions: &[Vector3<f64>] = bytemuck::cast_slice(particle_positions);
359-
360-
let attrs = attrs_conversion(attributes_to_interpolate);
361-
362353
let splashsurf::reconstruct::ReconstructionResult {
363354
tri_mesh,
364355
tri_quad_mesh,
365356
raw_reconstruction: reconstruction,
366357
} = reconstruction_pipeline_generic::<i64, f64>(
367-
particle_positions,
368-
attrs,
358+
particles,
359+
attributes_to_interpolate,
369360
particle_radius,
370361
rest_density,
371362
smoothing_length,
@@ -401,8 +392,7 @@ pub fn reconstruction_pipeline_py_f64<'py>(
401392
mesh_aabb_min,
402393
mesh_aabb_max,
403394
mesh_aabb_clamp_vertices,
404-
)
405-
.unwrap();
395+
)?;
406396

407397
Ok((
408398
tri_mesh.map(TriMeshWithDataF64::new),

splashsurf/src/reconstruct.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ use indicatif::{ProgressBar, ProgressStyle};
99
use log::{error, info, warn};
1010
use rayon::prelude::*;
1111
use splashsurf_lib::mesh::{
12-
Mesh3d, MeshWithData, MixedTriQuadMesh3d, OwnedAttributeData, OwnedMeshAttribute, TriMesh3d,
12+
AttributeData, Mesh3d, MeshAttribute, MeshWithData, MixedTriQuadMesh3d, OwnedAttributeData,
13+
OwnedMeshAttribute, TriMesh3d,
1314
};
1415
use splashsurf_lib::nalgebra::{Unit, Vector3};
1516
use splashsurf_lib::sph_interpolation::SphInterpolator;
@@ -18,7 +19,6 @@ use std::borrow::Cow;
1819
use std::collections::HashMap;
1920
use std::convert::TryFrom;
2021
use std::path::PathBuf;
21-
2222
// TODO: Detect smallest index type (i.e. check if ok to use i32 as index)
2323

2424
static ARGS_IO: &str = "Input/output";
@@ -1007,9 +1007,9 @@ pub(crate) fn reconstruction_pipeline_from_args(
10071007
/// * `attributes`: Note that the attributes are not required for the reconstruction itself but can be used for
10081008
/// post-processing steps like attribute interpolation to the reconstructed surface. This has to be enabled explicitly
10091009
/// in the post-processing parameters (see [`ReconstructionPostprocessingParameters`]).
1010-
pub fn reconstruction_pipeline<I: Index, R: Real>(
1010+
pub fn reconstruction_pipeline<'a, I: Index, R: Real>(
10111011
particle_positions: &[Vector3<R>],
1012-
attributes: &[OwnedMeshAttribute<R>],
1012+
attributes: &[MeshAttribute<'a, R>],
10131013
params: &splashsurf_lib::Parameters<R>,
10141014
postprocessing: &ReconstructionPostprocessingParameters,
10151015
) -> Result<ReconstructionResult<I, R>, anyhow::Error> {
@@ -1340,7 +1340,7 @@ pub fn reconstruction_pipeline<I: Index, R: Real>(
13401340

13411341
let particles_inside = reconstruction.particle_inside_aabb().map(Vec::as_slice);
13421342
match &attribute.data {
1343-
OwnedAttributeData::ScalarReal(values) => {
1343+
AttributeData::ScalarReal(values) => {
13441344
let filtered_values = filtered_quantity(&values, particles_inside);
13451345
let interpolated_values = interpolator.interpolate_scalar_quantity(
13461346
&filtered_values,
@@ -1354,7 +1354,7 @@ pub fn reconstruction_pipeline<I: Index, R: Real>(
13541354
OwnedAttributeData::ScalarReal(interpolated_values.into()),
13551355
));
13561356
}
1357-
OwnedAttributeData::Vector3Real(values) => {
1357+
AttributeData::Vector3Real(values) => {
13581358
let filtered_values = filtered_quantity(&values, particles_inside);
13591359
let interpolated_values = interpolator.interpolate_vector_quantity(
13601360
&filtered_values,

0 commit comments

Comments
 (0)