Skip to content

Commit 23809ee

Browse files
committed
Refactor
1 parent 7ec0c3c commit 23809ee

File tree

2 files changed

+66
-51
lines changed

2 files changed

+66
-51
lines changed

pysplashsurf/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ license.workspace = true
88
[dependencies]
99
splashsurf = { path = "../splashsurf" }
1010
splashsurf_lib = { path = "../splashsurf_lib" }
11-
pyo3 = {version = "0.25.0", features = ["anyhow"]}
11+
pyo3 = { version = "0.25.0", features = ["anyhow"] }
1212
numpy = "0.25.0"
1313
ndarray = "0.16.1"
1414
bytemuck = { version = "1.23.0", features = ["extern_crate_alloc"] }

pysplashsurf/src/mesh.rs

Lines changed: 65 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,13 @@ where
2929
let elem = attrs.iter().filter(|x| x.name == name).next();
3030
match elem {
3131
Some(attr) => match attr.data.clone() {
32-
OwnedAttributeData::ScalarU64(res) => {
33-
Ok(res.into_owned().into_pyobject(py).unwrap().into())
34-
}
35-
OwnedAttributeData::ScalarReal(res) => {
36-
Ok(res.into_owned().into_pyobject(py).unwrap().into())
37-
}
32+
OwnedAttributeData::ScalarU64(res) => Ok(res.into_owned().into_pyobject(py)?.into()),
33+
OwnedAttributeData::ScalarReal(res) => Ok(res.into_owned().into_pyobject(py)?.into()),
3834
OwnedAttributeData::Vector3Real(res) => {
3935
let flattened: Vec<R> = bytemuck::cast_vec(res.into_owned());
40-
let res: Array2<R> =
41-
Array2::from_shape_vec((flattened.len() / 3, 3), flattened).unwrap();
42-
Ok(res.into_pyarray(py).into_bound_py_any(py).unwrap().into())
36+
let res: Array2<R> = Array2::from_shape_vec((flattened.len() / 3, 3), flattened)
37+
.map_err(anyhow::Error::new)?;
38+
Ok(res.into_pyarray(py).into_bound_py_any(py)?.into())
4339
}
4440
},
4541
None => Err(PyErr::new::<PyValueError, _>(format!(
@@ -90,13 +86,13 @@ macro_rules! create_mesh_data_interface {
9086
Ok($name::new(meshdata))
9187
}
9288

93-
/// Clone of the contained mesh
89+
/// Returns a copy of the contained mesh
9490
#[getter]
9591
fn mesh(&self) -> $pymesh_class {
9692
$pymesh_class::new(self.inner.mesh.clone())
9793
}
9894

99-
/// Returns mesh without copying the mesh data, removes it from the object
95+
/// Returns the contained mesh by moving it out of this object (zero copy)
10096
fn take_mesh(&mut self) -> $pymesh_class {
10197
let mesh = std::mem::take(&mut self.inner.mesh);
10298
$pymesh_class::new(mesh)
@@ -143,8 +139,8 @@ macro_rules! create_mesh_data_interface {
143139
name: &str,
144140
data: &Bound<'py, PyArray2<$type>>,
145141
) -> PyResult<()> {
146-
let data: PyReadonlyArray2<$type> = data.extract().unwrap();
147-
let data = data.as_slice().unwrap();
142+
let data: PyReadonlyArray2<$type> = data.extract()?;
143+
let data = data.as_slice()?;
148144
let data: &[Vector3<$type>] = bytemuck::cast_slice(data);
149145

150146
add_attribute_with_name::<$type>(
@@ -183,8 +179,8 @@ macro_rules! create_mesh_data_interface {
183179
name: &str,
184180
data: &Bound<'py, PyArray2<$type>>,
185181
) -> PyResult<()> {
186-
let data: PyReadonlyArray2<$type> = data.extract().unwrap();
187-
let data = data.as_slice().unwrap();
182+
let data: PyReadonlyArray2<$type> = data.extract()?;
183+
let data = data.as_slice()?;
188184
let data: &[Vector3<$type>] = bytemuck::cast_slice(data);
189185

190186
add_attribute_with_name::<$type>(
@@ -207,7 +203,7 @@ macro_rules! create_mesh_data_interface {
207203
}
208204

209205
/// Get all point attributes in a python dictionary
210-
fn get_point_attributes<'py>(&self, py: Python<'py>) -> Bound<'py, PyDict> {
206+
fn get_point_attributes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
211207
let res = PyDict::new(py);
212208

213209
for attr in self.inner.point_attributes.iter() {
@@ -217,16 +213,16 @@ macro_rules! create_mesh_data_interface {
217213
&attr.name,
218214
);
219215
match data {
220-
Ok(data) => res.set_item(&attr.name, data).unwrap(),
216+
Ok(data) => res.set_item(&attr.name, data)?,
221217
Err(_) => println!("Couldn't embed attribute {} in PyDict", &attr.name),
222218
}
223219
}
224220

225-
res
221+
Ok(res)
226222
}
227223

228224
/// Get all cell attributes in a python dictionary
229-
fn get_cell_attributes<'py>(&self, py: Python<'py>) -> Bound<'py, PyDict> {
225+
fn get_cell_attributes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
230226
let res = PyDict::new(py);
231227

232228
for attr in self.inner.cell_attributes.iter() {
@@ -236,40 +232,46 @@ macro_rules! create_mesh_data_interface {
236232
&attr.name,
237233
);
238234
match data {
239-
Ok(data) => res.set_item(&attr.name, data).unwrap(),
235+
Ok(data) => res.set_item(&attr.name, data)?,
240236
Err(_) => println!("Couldn't embed attribute {} in PyDict", &attr.name),
241237
}
242238
}
243239

244-
res
240+
Ok(res)
245241
}
246242

247243
/// Get all registered point attribute names
248-
fn get_point_attribute_keys<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
244+
fn get_point_attribute_keys<'py>(
245+
&self,
246+
py: Python<'py>,
247+
) -> PyResult<Bound<'py, PyList>> {
249248
let mut res: Vec<&str> = vec![];
250249

251250
for attr in self.inner.point_attributes.iter() {
252251
res.push(&attr.name);
253252
}
254253

255-
PyList::new(py, res).unwrap()
254+
PyList::new(py, res)
256255
}
257256

258257
/// Get all registered cell attribute names
259-
fn get_cell_attribute_keys<'py>(&self, py: Python<'py>) -> Bound<'py, PyList> {
258+
fn get_cell_attribute_keys<'py>(
259+
&self,
260+
py: Python<'py>,
261+
) -> PyResult<Bound<'py, PyList>> {
260262
let mut res: Vec<&str> = vec![];
261263

262264
for attr in self.inner.cell_attributes.iter() {
263265
res.push(&attr.name);
264266
}
265267

266-
PyList::new(py, res).unwrap()
268+
PyList::new(py, res)
267269
}
268270
}
269271
};
270272
}
271273

272-
macro_rules! create_mesh_interface {
274+
macro_rules! create_tri_mesh_interface {
273275
($name: ident, $type: ident) => {
274276
/// TriMesh3d wrapper
275277
#[gen_stub_pyclass]
@@ -289,24 +291,29 @@ macro_rules! create_mesh_interface {
289291
impl $name {
290292
/// nx3 array of vertex positions, copies the data
291293
#[getter]
292-
fn vertices<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<$type>> {
294+
fn vertices<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<$type>>> {
293295
let points: &[$type] = bytemuck::cast_slice(&self.inner.vertices);
294296
let vertices: ArrayView2<$type> =
295-
ArrayView::from_shape((self.inner.vertices.len(), 3), points).unwrap();
296-
vertices.to_pyarray(py) // seems like at least one copy is necessary here (to_pyarray copies the data)
297+
ArrayView::from_shape((self.inner.vertices.len(), 3), points)
298+
.map_err(anyhow::Error::new)?;
299+
Ok(vertices.to_pyarray(py)) // seems like at least one copy is necessary here (to_pyarray copies the data)
297300
}
298301

299302
/// nx3 array of the vertex indices that make up a triangle, copies the data
300303
#[getter]
301-
fn triangles<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<u64>> {
304+
fn triangles<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<u64>>> {
302305
let tris: &[u64] = bytemuck::cast_slice(&self.inner.triangles);
303306
let triangles: ArrayView2<u64> =
304-
ArrayView::from_shape((self.inner.triangles.len(), 3), tris).unwrap();
305-
triangles.to_pyarray(py)
307+
ArrayView::from_shape((self.inner.triangles.len(), 3), tris)
308+
.map_err(anyhow::Error::new)?;
309+
Ok(triangles.to_pyarray(py))
306310
}
307311

308-
/// Returns a tuple of vertices and triangles without copying the data, removes the data in the class
309-
fn take_vertices_and_triangles<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyTuple> {
312+
/// Returns a tuple containing the vertices and triangles by moving them out of the mesh (zero copy)
313+
fn take_vertices_and_triangles<'py>(
314+
&mut self,
315+
py: Python<'py>,
316+
) -> PyResult<Bound<'py, PyTuple>> {
310317
let vertices = std::mem::take(&mut self.inner.vertices);
311318
let triangles = std::mem::take(&mut self.inner.triangles);
312319

@@ -316,28 +323,32 @@ macro_rules! create_mesh_interface {
316323
let vertices_scalar: Vec<$type> = bytemuck::cast_vec(vertices);
317324
let vertices_array = PyArray::from_vec(py, vertices_scalar)
318325
.reshape([n, 3])
319-
.unwrap();
326+
.map_err(anyhow::Error::new)?;
320327

321328
let triangles_scalar: Vec<usize> = bytemuck::cast_vec(triangles);
322329
let triangles_array = PyArray::from_vec(py, triangles_scalar)
323330
.reshape([m, 3])
324-
.unwrap();
331+
.map_err(anyhow::Error::new)?;
325332

326333
let tup = (vertices_array, triangles_array);
327-
tup.into_pyobject(py).unwrap()
334+
tup.into_pyobject(py)
328335
}
329336

330337
/// Computes the mesh's vertex normals using an area weighted average of the adjacent triangle faces (parallelized version)
331-
fn par_vertex_normals<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<$type>> {
338+
fn par_vertex_normals<'py>(
339+
&self,
340+
py: Python<'py>,
341+
) -> PyResult<Bound<'py, PyArray2<$type>>> {
332342
let normals_vec = self.inner.par_vertex_normals();
333343
let normals_vec =
334344
bytemuck::allocation::cast_vec::<Unit<Vector3<$type>>, $type>(normals_vec);
335345

336346
let normals: &[$type] = normals_vec.as_slice();
337347
let normals: ArrayView2<$type> =
338-
ArrayView::from_shape((normals.len() / 3, 3), normals).unwrap();
348+
ArrayView::from_shape((normals.len() / 3, 3), normals)
349+
.map_err(anyhow::Error::new)?;
339350

340-
normals.to_pyarray(py)
351+
Ok(normals.to_pyarray(py))
341352
}
342353

343354
/// Returns a mapping of all mesh vertices to the set of their connected neighbor vertices
@@ -368,16 +379,17 @@ macro_rules! create_tri_quad_mesh_interface {
368379
impl $name {
369380
/// nx3 array of vertex positions, copies data
370381
#[getter]
371-
fn vertices<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<$type>> {
382+
fn vertices<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<$type>>> {
372383
let points: &[$type] = bytemuck::cast_slice(&self.inner.vertices);
373384
let vertices: ArrayView2<$type> =
374-
ArrayView::from_shape((self.inner.vertices.len(), 3), points).unwrap();
375-
vertices.to_pyarray(py)
385+
ArrayView::from_shape((self.inner.vertices.len(), 3), points)
386+
.map_err(anyhow::Error::new)?;
387+
Ok(vertices.to_pyarray(py))
376388
}
377389

378390
/// 2D list specifying the vertex indices either for a triangle or a quad
379391
#[getter]
380-
fn cells(&self) -> Vec<Vec<usize>> {
392+
fn cells(&self) -> PyResult<Vec<Vec<usize>>> {
381393
let cells: Vec<Vec<usize>> = self
382394
.inner
383395
.cells
@@ -387,11 +399,14 @@ macro_rules! create_tri_quad_mesh_interface {
387399
TriangleOrQuadCell::Quad(v) => v.to_vec(),
388400
})
389401
.collect();
390-
cells
402+
Ok(cells)
391403
}
392404

393405
/// Returns a tuple of vertices and triangles without copying the data, removes the data in the class
394-
fn take_vertices_and_cells<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyTuple> {
406+
fn take_vertices_and_cells<'py>(
407+
&mut self,
408+
py: Python<'py>,
409+
) -> PyResult<Bound<'py, PyTuple>> {
395410
let vertices = std::mem::take(&mut self.inner.vertices);
396411
let cells = std::mem::take(&mut self.inner.cells);
397412

@@ -400,7 +415,7 @@ macro_rules! create_tri_quad_mesh_interface {
400415
let vertices_scalar: Vec<$type> = bytemuck::cast_vec(vertices);
401416
let vertices_array = PyArray::from_vec(py, vertices_scalar)
402417
.reshape([n, 3])
403-
.unwrap();
418+
.map_err(anyhow::Error::new)?;
404419

405420
let cells_list: Vec<Vec<usize>> = cells
406421
.into_iter()
@@ -411,14 +426,14 @@ macro_rules! create_tri_quad_mesh_interface {
411426
.collect();
412427

413428
let tup = (vertices_array, cells_list);
414-
tup.into_pyobject(py).unwrap()
429+
tup.into_pyobject(py)
415430
}
416431
}
417432
};
418433
}
419434

420-
create_mesh_interface!(TriMesh3dF64, f64);
421-
create_mesh_interface!(TriMesh3dF32, f32);
435+
create_tri_mesh_interface!(TriMesh3dF64, f64);
436+
create_tri_mesh_interface!(TriMesh3dF32, f32);
422437

423438
create_tri_quad_mesh_interface!(MixedTriQuadMesh3dF64, f64);
424439
create_tri_quad_mesh_interface!(MixedTriQuadMesh3dF32, f32);

0 commit comments

Comments
 (0)