Skip to content

Commit fb2c38f

Browse files
committed
perf(speed): memcpy directly for np.array to std::vector<Eigen::Vector3d>
1 parent d1490a3 commit fb2c38f

File tree

8 files changed

+127
-44
lines changed

8 files changed

+127
-44
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ Running on macOS, Windows and Linux, with Python Version >= 3.8.
1010
Available in: <a href="https://github.com/Kin-Zhang/linefit"><img src="https://img.shields.io/badge/Windows-0078D6?st&logo=windows&logoColor=white" /> <a href="https://github.com/Kin-Zhang/linefit"><img src="https://img.shields.io/badge/Linux-FCC624?logo=linux&logoColor=black" /> <a href="https://github.com/Kin-Zhang/linefit"><img src="https://img.shields.io/badge/mac%20os-000000?&logo=apple&logoColor=white" /> </a>
1111

1212
<!-- -->
13+
📜 Change Log:
14+
- 2024-07-03: Speed up nanobind `np.array` <-> `std::vector<Eigen:: Vector3d>` conversion and also `NOMINSIZE` in make. Speed difference: 0.1s -> 0.01s. Based on [discussion here](https://github.com/wjakob/nanobind/discussions/426).
15+
- 2024-02-15: Initial version.
1316

1417
## 0. Setup
1518

cpp/linefit/ground_segmentation.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,7 @@ GroundSegmentation::GroundSegmentation(const std::string &toml_file) {
6363

6464
}
6565

66-
std::vector<bool> GroundSegmentation::segment(const std::vector<std::vector<float>> points) {
67-
// TODO: Maybe there is a better way to convert the points to Eigen::Vector3d
68-
PointCloud cloud;
69-
for (auto point : points) {
70-
cloud.push_back(Eigen::Vector3d(point[0], point[1], point[2]));
71-
}
66+
std::vector<bool> GroundSegmentation::segment(const PointCloud &cloud) {
7267
if (verbose_)
7368
std::cout << "Segmenting cloud with " << cloud.size() << " points...\n";
7469

cpp/linefit/ground_segmentation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,6 @@ class GroundSegmentation {
101101
GroundSegmentation(const std::string& toml_file);
102102
// virtual ~GroundSegmentation() = default;
103103

104-
std::vector<bool> segment(const std::vector<std::vector<float>> points);
104+
std::vector<bool> segment(const PointCloud &cloud);
105105

106106
};

example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
groundseg = ground_seg()
2727
else:
2828
groundseg = ground_seg(config_path)
29-
label = np.array(groundseg.run(pc_data[:,:3].tolist()))
29+
label = np.array(groundseg.run(pc_data[:,:3]))
3030
print(f"point cloud shape: {pc_data[:, :3].shape}, label shape: {label.shape}, ground points: {np.sum(label)}, time: {time.time() - start_time:.3f} s")
3131

3232
from python.utils.o3d_view import MyVisualizer

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build"
44

55
[project]
66
name = "linefit"
7-
version = "0.2.4"
7+
version = "1.0.0"
88
description = "linefit ground segmentation algorithm Python bindings"
99
readme = "README.md"
1010
requires-python = ">=3.8"

python/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
nanobind_add_module(linefit MODULE linefit_ext.cpp)
1+
nanobind_add_module(linefit MODULE NOMINSIZE linefit_ext.cpp)
22
target_link_libraries(linefit PRIVATE gm_lib)
33
install(TARGETS linefit LIBRARY DESTINATION .)

python/linefit_ext.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,28 @@
3939
*/
4040

4141
#include <Eigen/Core>
42-
#include "nanobind/nanobind.h"
42+
43+
#include <nanobind/nanobind.h>
4344
#include <nanobind/stl/string.h>
4445
#include <nanobind/stl/vector.h>
4546
#include <nanobind/ndarray.h>
46-
// #include <nanobind/stl/bind_vector.h>
4747

4848
#include "../cpp/linefit/ground_segmentation.h"
4949
#include "../cpp/linefit/mics.h"
50-
// #include "stl_vector_eigen.h"
5150

5251
namespace nb = nanobind;
5352
using namespace nb::literals;
5453

5554
NB_MODULE(linefit, m) {
5655
nb::class_<GroundSegmentation>(m, "ground_seg")
57-
.def(nb::init<>(), "linefit ground segmentation constructor, param: TODO")
56+
.def(nb::init<>(), "linefit ground segmentation constructor, param: check default config to know more.")
5857
.def(nb::init<const std::string &>(), "linefit ground segmentation constructor, with toml file as param file input.")
59-
// .def("run", nb::overload_cast<std::vector<Eigen::Vector3d> &>(&GroundSegmentation::segment), "points"_a);
60-
.def("run", &GroundSegmentation::segment, "points"_a, nanobind::rv_policy::reference);
61-
}
58+
.def("run", [](GroundSegmentation& self, const nb::ndarray<double>& array) -> std::vector<bool> {
59+
if (array.ndim() != 2 || array.shape(1) != 3) {
60+
throw std::runtime_error("Input array must have shape (N, 3)");
61+
}
62+
std::vector<Eigen::Vector3d> points_vec(array.shape(0));
63+
std::memcpy(points_vec.data(), array.data(), array.size() * sizeof(double));
64+
return self.segment(points_vec);
65+
}, "points"_a);
66+
}

python/utils/o3d_view.py

Lines changed: 107 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,40 @@
11
'''
22
# Created: 2023-1-26 16:38
3-
# Copyright (C) 2022-now, RPL, KTH Royal Institute of Technology
4-
# Author: Kin ZHANG (https://kin-zhang.github.io/)
5-
6-
# This work is licensed under the terms of the MIT license.
7-
# For a copy, see <https://opensource.org/licenses/MIT>.
3+
# Updated: 2024-04-15 12:06
4+
# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology
5+
# Author: Qingwen ZHANG (https://kin-zhang.github.io/)
6+
#
7+
# code gits: https://gist.github.com/Kin-Zhang/77e8aa77a998f1a4f7495357843f24ef
8+
#
9+
# Description as follows:
810
911
This file is for open3d view control set from view_file, which should be json
1012
1. use normal way to open any geometry and set view by mouse you want
1113
2. `CTRL+C` it will copy the view detail at this moment.
1214
3. `CTRL+V` to json file, you can create new one
1315
4. give the json file path
14-
Check this part: http://www.open3d.org/docs/release/tutorial/visualization/visualization.html#Store-view-point
1516
16-
Test if you want by run this script
17+
Check this part: http://www.open3d.org/docs/release/tutorial/visualization/visualization.html#Store-view-point
1718
18-
Then press 'V' on keyboard, will set from json
19+
Test if you want by run this script: by press 'V' on keyboard, will set from json
1920
21+
# CHANGELOG:
22+
# 2024-04-15 12:06(Qingwen): show a example json text. add hex_to_rgb, color_map_hex, color_map (for color points if needed)
23+
# 2024-01-27 0:41(Qingwen): update MyVisualizer class, reference from kiss-icp
24+
[python/kiss-icp/tools/visualizer.py](https://github.com/PRBonn/kiss-icp/blob/main/python/kiss_icp/tools/visualizer.py)
2025
'''
2126

2227
import open3d as o3d
2328
import json
2429
import os, sys
25-
from typing import List
26-
BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '..' ))
27-
sys.path.append(BASE_DIR)
30+
from typing import List, Callable
31+
from functools import partial
2832

2933
def hex_to_rgb(hex_color):
3034
hex_color = hex_color.lstrip("#")
3135
return tuple(int(hex_color[i:i + 2], 16) / 255.0 for i in (0, 2, 4))
3236

33-
color_map_hex = ['#a6cee3','#1f78b4','#b2df8a','#33a02c','#fb9a99','#e31a1c','#fdbf6f','#ff7f00','#cab2d6','#6a3d9a','#ffff99','#b15928',\
37+
color_map_hex = ['#a6cee3', '#de2d26', '#1f78b4','#b2df8a','#33a02c','#fb9a99','#e31a1c','#fdbf6f','#ff7f00','#cab2d6','#6a3d9a','#ffff99','#b15928',\
3438
'#8dd3c7','#ffffb3','#bebada','#fb8072','#80b1d3','#fdb462','#b3de69','#fccde5','#d9d9d9','#bc80bd','#ccebc5','#ffed6f']
3539
color_map = [hex_to_rgb(color) for color in color_map_hex]
3640

@@ -40,16 +44,24 @@ def __init__(self, vctrl: o3d.visualization.ViewControl, view_file=None):
4044
self.params = None
4145
if view_file is not None:
4246
print(f"Init with view_file from: {view_file}")
43-
self.parase_file(view_file)
47+
self.parse_file(view_file)
4448
self.set_param()
4549
else:
4650
print("Init without view_file")
51+
4752
def read_viewTfile(self, view_file):
48-
self.parase_file(view_file)
53+
if view_file is None:
54+
return
55+
self.parse_file(view_file)
4956
self.set_param()
57+
5058
def save_viewTfile(self, view_file):
5159
return
52-
def parase_file(self, view_file):
60+
61+
def parse_file(self, view_file):
62+
if view_file is None:
63+
print(f"\033[91mNo specific view file. Skip to setup viewpoint in open3d. \033[0m")
64+
return
5365
if(os.path.exists(view_file)):
5466
with open((view_file)) as user_file:
5567
file_contents = user_file.read()
@@ -69,26 +81,94 @@ def set_param(self):
6981
class MyVisualizer:
7082
def __init__(self, view_file=None, window_title="Default"):
7183
self.params = None
72-
self.viz = o3d.visualization.VisualizerWithKeyCallback()
73-
self.viz.create_window(window_name=window_title)
74-
self.o3d_vctrl = ViewControl(self.viz.get_view_control(), view_file=view_file)
84+
self.vis = o3d.visualization.VisualizerWithKeyCallback()
85+
self.vis.create_window(window_name=window_title)
86+
self.o3d_vctrl = ViewControl(self.vis.get_view_control(), view_file=view_file)
7587
self.view_file = view_file
76-
88+
89+
self.block_vis = True
90+
self.play_crun = False
91+
self.reset_bounding_box = True
92+
print(
93+
f"\n{window_title.capitalize()} initialized. Press:\n"
94+
"\t[SPACE] to pause/start\n"
95+
"\t [ESC] to exit\n"
96+
"\t [N] to step\n"
97+
)
98+
self._register_key_callback(["Ā", "Q", "\x1b"], self._quit)
99+
self._register_key_callback([" "], self._start_stop)
100+
self._register_key_callback(["N"], self._next_frame)
101+
77102
def show(self, assets: List):
78-
self.viz.clear_geometries()
103+
self.vis.clear_geometries()
79104

80105
for asset in assets:
81-
self.viz.add_geometry(asset)
82-
if self.view_file is not None:
106+
self.vis.add_geometry(asset)
83107
self.o3d_vctrl.read_viewTfile(self.view_file)
84108

85-
self.viz.update_renderer()
86-
self.viz.poll_events()
87-
self.viz.run()
88-
self.viz.destroy_window()
109+
self.vis.update_renderer()
110+
self.vis.poll_events()
111+
self.vis.run()
112+
self.vis.destroy_window()
113+
114+
def update(self, assets: List, clear: bool = True):
115+
if clear:
116+
self.vis.clear_geometries()
117+
118+
for asset in assets:
119+
self.vis.add_geometry(asset, reset_bounding_box=False)
120+
self.vis.update_geometry(asset)
121+
122+
if self.reset_bounding_box:
123+
self.vis.reset_view_point(True)
124+
if self.view_file is not None:
125+
self.o3d_vctrl.read_viewTfile(self.view_file)
126+
self.reset_bounding_box = False
127+
128+
self.vis.update_renderer()
129+
while self.block_vis:
130+
self.vis.poll_events()
131+
if self.play_crun:
132+
break
133+
self.block_vis = not self.block_vis
134+
135+
def _register_key_callback(self, keys: List, callback: Callable):
136+
for key in keys:
137+
self.vis.register_key_callback(ord(str(key)), partial(callback))
138+
def _next_frame(self, vis):
139+
self.block_vis = not self.block_vis
140+
def _start_stop(self, vis):
141+
self.play_crun = not self.play_crun
142+
def _quit(self, vis):
143+
print("Destroying Visualizer. Thanks for using ^v^.")
144+
vis.destroy_window()
145+
os._exit(0)
89146

90147
if __name__ == "__main__":
91-
view_json_file = f"{BASE_DIR}/data/o3d_view/default_test.json"
148+
json_content = """{
149+
"class_name" : "ViewTrajectory",
150+
"interval" : 29,
151+
"is_loop" : false,
152+
"trajectory" :
153+
[
154+
{
155+
"boundingbox_max" : [ 3.9660897254943848, 2.427476167678833, 2.55859375 ],
156+
"boundingbox_min" : [ 0.55859375, 0.83203125, 0.56663715839385986 ],
157+
"field_of_view" : 60.0,
158+
"front" : [ 0.27236083595988803, -0.25567329763523589, -0.92760484038816615 ],
159+
"lookat" : [ 2.4114965637897101, 1.8070288935660688, 1.5662280268112718 ],
160+
"up" : [ -0.072779625398507866, -0.96676294585190281, 0.24509698622097265 ],
161+
"zoom" : 0.47999999999999976
162+
}
163+
],
164+
"version_major" : 1,
165+
"version_minor" : 0
166+
}
167+
"""
168+
# write to json file
169+
view_json_file = "view.json"
170+
with open(view_json_file, 'w') as f:
171+
f.write(json_content)
92172
sample_ply_data = o3d.data.PLYPointCloud()
93173
pcd = o3d.io.read_point_cloud(sample_ply_data.path)
94174
# 1. define

0 commit comments

Comments
 (0)