Skip to content

Commit d64b847

Browse files
committed
Used formatting tool.
1 parent e0c78f5 commit d64b847

File tree

10 files changed

+501
-382
lines changed

10 files changed

+501
-382
lines changed

elevation_mapping_cupy/script/custom_kernels.py

Lines changed: 166 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,19 @@
66
import string
77

88

9-
def map_utils(resolution, width, height, sensor_noise_factor, min_valid_distance, max_height_range,
10-
ramped_height_range_a, ramped_height_range_b, ramped_height_range_c):
11-
util_preamble = string.Template('''
9+
def map_utils(
10+
resolution,
11+
width,
12+
height,
13+
sensor_noise_factor,
14+
min_valid_distance,
15+
max_height_range,
16+
ramped_height_range_a,
17+
ramped_height_range_b,
18+
ramped_height_range_c,
19+
):
20+
util_preamble = string.Template(
21+
"""
1222
__device__ float16 clamp(float16 x, float16 min_x, float16 max_x) {
1323
1424
return max(min(x, max_x), min_x);
@@ -97,32 +107,57 @@ def map_utils(resolution, width, height, sensor_noise_factor, min_valid_distance
97107
return product;
98108
}
99109
100-
''').substitute(resolution=resolution, width=width, height=height,
101-
sensor_noise_factor=sensor_noise_factor,
102-
min_valid_distance=min_valid_distance,
103-
max_height_range=max_height_range,
104-
ramped_height_range_a=ramped_height_range_a,
105-
ramped_height_range_b=ramped_height_range_b,
106-
ramped_height_range_c=ramped_height_range_c,
107-
)
110+
"""
111+
).substitute(
112+
resolution=resolution,
113+
width=width,
114+
height=height,
115+
sensor_noise_factor=sensor_noise_factor,
116+
min_valid_distance=min_valid_distance,
117+
max_height_range=max_height_range,
118+
ramped_height_range_a=ramped_height_range_a,
119+
ramped_height_range_b=ramped_height_range_b,
120+
ramped_height_range_c=ramped_height_range_c,
121+
)
108122
return util_preamble
109123

110124

111-
def add_points_kernel(resolution, width, height, sensor_noise_factor,
112-
mahalanobis_thresh, outlier_variance, wall_num_thresh,
113-
max_ray_length, cleanup_step, min_valid_distance,
114-
max_height_range, cleanup_cos_thresh,
115-
ramped_height_range_a, ramped_height_range_b, ramped_height_range_c,
116-
enable_edge_shaped=True, enable_visibility_cleanup=True):
125+
def add_points_kernel(
126+
resolution,
127+
width,
128+
height,
129+
sensor_noise_factor,
130+
mahalanobis_thresh,
131+
outlier_variance,
132+
wall_num_thresh,
133+
max_ray_length,
134+
cleanup_step,
135+
min_valid_distance,
136+
max_height_range,
137+
cleanup_cos_thresh,
138+
ramped_height_range_a,
139+
ramped_height_range_b,
140+
ramped_height_range_c,
141+
enable_edge_shaped=True,
142+
enable_visibility_cleanup=True,
143+
):
117144

118145
add_points_kernel = cp.ElementwiseKernel(
119-
in_params='raw U p, raw U center_x, raw U center_y, raw U R, raw U t, raw U norm_map',
120-
out_params='raw U map, raw T newmap',
121-
preamble=map_utils(resolution, width, height, sensor_noise_factor, min_valid_distance, max_height_range,
122-
ramped_height_range_a, ramped_height_range_b, ramped_height_range_c),
123-
operation=\
124-
string.Template(
125-
'''
146+
in_params="raw U p, raw U center_x, raw U center_y, raw U R, raw U t, raw U norm_map",
147+
out_params="raw U map, raw T newmap",
148+
preamble=map_utils(
149+
resolution,
150+
width,
151+
height,
152+
sensor_noise_factor,
153+
min_valid_distance,
154+
max_height_range,
155+
ramped_height_range_a,
156+
ramped_height_range_b,
157+
ramped_height_range_c,
158+
),
159+
operation=string.Template(
160+
"""
126161
U rx = p[i * 3];
127162
U ry = p[i * 3 + 1];
128163
U rz = p[i * 3 + 2];
@@ -223,33 +258,54 @@ def add_points_kernel(resolution, width, height, sensor_noise_factor,
223258
}
224259
}
225260
}
226-
''').substitute(mahalanobis_thresh=mahalanobis_thresh,
227-
outlier_variance=outlier_variance,
228-
wall_num_thresh=wall_num_thresh,
229-
ray_step=resolution / 2**0.5,
230-
max_ray_length=max_ray_length,
231-
cleanup_step=cleanup_step,
232-
cleanup_cos_thresh=cleanup_cos_thresh,
233-
enable_edge_shaped=int(enable_edge_shaped),
234-
enable_visibility_cleanup=int(enable_visibility_cleanup)),
235-
name='add_points_kernel')
261+
"""
262+
).substitute(
263+
mahalanobis_thresh=mahalanobis_thresh,
264+
outlier_variance=outlier_variance,
265+
wall_num_thresh=wall_num_thresh,
266+
ray_step=resolution / 2**0.5,
267+
max_ray_length=max_ray_length,
268+
cleanup_step=cleanup_step,
269+
cleanup_cos_thresh=cleanup_cos_thresh,
270+
enable_edge_shaped=int(enable_edge_shaped),
271+
enable_visibility_cleanup=int(enable_visibility_cleanup),
272+
),
273+
name="add_points_kernel",
274+
)
236275
return add_points_kernel
237276

238277

239-
def error_counting_kernel(resolution, width, height, sensor_noise_factor,
240-
mahalanobis_thresh, outlier_variance,
241-
traversability_inlier, min_valid_distance, max_height_range,
242-
ramped_height_range_a, ramped_height_range_b, ramped_height_range_c,
243-
):
278+
def error_counting_kernel(
279+
resolution,
280+
width,
281+
height,
282+
sensor_noise_factor,
283+
mahalanobis_thresh,
284+
outlier_variance,
285+
traversability_inlier,
286+
min_valid_distance,
287+
max_height_range,
288+
ramped_height_range_a,
289+
ramped_height_range_b,
290+
ramped_height_range_c,
291+
):
244292

245293
error_counting_kernel = cp.ElementwiseKernel(
246-
in_params='raw U map, raw U p, raw U center_x, raw U center_y, raw U R, raw U t',
247-
out_params='raw U newmap, raw T error, raw T error_cnt',
248-
preamble=map_utils(resolution, width, height, sensor_noise_factor, min_valid_distance, max_height_range,
249-
ramped_height_range_a, ramped_height_range_b, ramped_height_range_c),
250-
operation=\
251-
string.Template(
252-
'''
294+
in_params="raw U map, raw U p, raw U center_x, raw U center_y, raw U R, raw U t",
295+
out_params="raw U newmap, raw T error, raw T error_cnt",
296+
preamble=map_utils(
297+
resolution,
298+
width,
299+
height,
300+
sensor_noise_factor,
301+
min_valid_distance,
302+
max_height_range,
303+
ramped_height_range_a,
304+
ramped_height_range_b,
305+
ramped_height_range_c,
306+
),
307+
operation=string.Template(
308+
"""
253309
U rx = p[i * 3];
254310
U ry = p[i * 3 + 1];
255311
U rz = p[i * 3 + 2];
@@ -277,26 +333,31 @@ def error_counting_kernel(resolution, width, height, sensor_noise_factor,
277333
atomicAdd(&newmap[get_map_idx(idx, 3)], 1.0);
278334
}
279335
atomicAdd(&newmap[get_map_idx(idx, 4)], 1.0);
280-
''').substitute(mahalanobis_thresh=mahalanobis_thresh,
281-
outlier_variance=outlier_variance,
282-
traversability_inlier=traversability_inlier),
283-
name='error_counting_kernel')
336+
"""
337+
).substitute(
338+
mahalanobis_thresh=mahalanobis_thresh,
339+
outlier_variance=outlier_variance,
340+
traversability_inlier=traversability_inlier,
341+
),
342+
name="error_counting_kernel",
343+
)
284344
return error_counting_kernel
285345

286346

287347
def average_map_kernel(width, height, max_variance, initial_variance):
288348
average_map_kernel = cp.ElementwiseKernel(
289-
in_params='raw U newmap',
290-
out_params='raw U map',
291-
preamble=\
292-
string.Template('''
349+
in_params="raw U newmap",
350+
out_params="raw U map",
351+
preamble=string.Template(
352+
"""
293353
__device__ int get_map_idx(int idx, int layer_n) {
294354
const int layer = ${width} * ${height};
295355
return layer * layer_n + idx;
296356
}
297-
''').substitute(width=width, height=height),
298-
operation=\
299-
string.Template('''
357+
"""
358+
).substitute(width=width, height=height),
359+
operation=string.Template(
360+
"""
300361
U h = map[get_map_idx(i, 0)];
301362
U v = map[get_map_idx(i, 1)];
302363
U valid = map[get_map_idx(i, 2)];
@@ -320,18 +381,19 @@ def average_map_kernel(width, height, max_variance, initial_variance):
320381
map[get_map_idx(i, 1)] = ${initial_variance};
321382
map[get_map_idx(i, 2)] = 0;
322383
}
323-
''').substitute(max_variance=max_variance,
324-
initial_variance=initial_variance),
325-
name='average_map_kernel')
384+
"""
385+
).substitute(max_variance=max_variance, initial_variance=initial_variance),
386+
name="average_map_kernel",
387+
)
326388
return average_map_kernel
327389

328390

329391
def dilation_filter_kernel(width, height, dilation_size):
330392
dilation_filter_kernel = cp.ElementwiseKernel(
331-
in_params='raw U map, raw U mask',
332-
out_params='raw U newmap, raw U newmask',
333-
preamble=\
334-
string.Template('''
393+
in_params="raw U map, raw U mask",
394+
out_params="raw U newmap, raw U newmask",
395+
preamble=string.Template(
396+
"""
335397
__device__ int get_map_idx(int idx, int layer_n) {
336398
const int layer = ${width} * ${height};
337399
return layer * layer_n + idx;
@@ -353,9 +415,10 @@ def dilation_filter_kernel(width, height, dilation_size):
353415
}
354416
return true;
355417
}
356-
''').substitute(width=width, height=height),
357-
operation=\
358-
string.Template('''
418+
"""
419+
).substitute(width=width, height=height),
420+
operation=string.Template(
421+
"""
359422
U h = map[get_map_idx(i, 0)];
360423
U valid = mask[get_map_idx(i, 0)];
361424
newmap[get_map_idx(i, 0)] = h;
@@ -378,17 +441,19 @@ def dilation_filter_kernel(width, height, dilation_size):
378441
newmask[get_map_idx(i, 0)] = 1.0;
379442
}
380443
}
381-
''').substitute(dilation_size=dilation_size),
382-
name='dilation_filter_kernel')
444+
"""
445+
).substitute(dilation_size=dilation_size),
446+
name="dilation_filter_kernel",
447+
)
383448
return dilation_filter_kernel
384449

385450

386451
def normal_filter_kernel(width, height, resolution):
387452
normal_filter_kernel = cp.ElementwiseKernel(
388-
in_params='raw U map, raw U mask',
389-
out_params='raw U newmap',
390-
preamble=\
391-
string.Template('''
453+
in_params="raw U map, raw U mask",
454+
out_params="raw U newmap",
455+
preamble=string.Template(
456+
"""
392457
__device__ int get_map_idx(int idx, int layer_n) {
393458
const int layer = ${width} * ${height};
394459
return layer * layer_n + idx;
@@ -413,9 +478,10 @@ def normal_filter_kernel(width, height, resolution):
413478
__device__ float resolution() {
414479
return ${resolution};
415480
}
416-
''').substitute(width=width, height=height, resolution=resolution),
417-
operation=\
418-
string.Template('''
481+
"""
482+
).substitute(width=width, height=height, resolution=resolution),
483+
operation=string.Template(
484+
"""
419485
U h = map[get_map_idx(i, 0)];
420486
U valid = mask[get_map_idx(i, 0)];
421487
if (valid > 0.5) {
@@ -432,17 +498,19 @@ def normal_filter_kernel(width, height, resolution):
432498
newmap[get_map_idx(i, 1)] = ny / norm;
433499
newmap[get_map_idx(i, 2)] = nz / norm;
434500
}
435-
''').substitute(),
436-
name='normal_filter_kernel')
501+
"""
502+
).substitute(),
503+
name="normal_filter_kernel",
504+
)
437505
return normal_filter_kernel
438506

439507

440508
def polygon_mask_kernel(width, height, resolution):
441509
polygon_mask_kernel = cp.ElementwiseKernel(
442-
in_params='raw U polygon, raw U center_x, raw U center_y, raw int16 polygon_n, raw U polygon_bbox',
443-
out_params='raw U mask',
444-
preamble=\
445-
string.Template('''
510+
in_params="raw U polygon, raw U center_x, raw U center_y, raw int16 polygon_n, raw U polygon_bbox",
511+
out_params="raw U mask",
512+
preamble=string.Template(
513+
"""
446514
__device__ struct Point
447515
{
448516
int x;
@@ -533,9 +601,10 @@ def polygon_mask_kernel(width, height, resolution):
533601
return ${width} * idx_x + idx_y;
534602
}
535603
536-
''').substitute(width=width, height=height, resolution=resolution),
537-
operation=\
538-
string.Template('''
604+
"""
605+
).substitute(width=width, height=height, resolution=resolution),
606+
operation=string.Template(
607+
"""
539608
// Point p = {get_idx_x(i, center_x[0]), get_idx_y(i, center_y[0])};
540609
Point p = {get_idx_x(i), get_idx_y(i)};
541610
Point extreme = {100000, p.y};
@@ -577,19 +646,24 @@ def polygon_mask_kernel(width, height, resolution):
577646
if (intersect_cnt % 2 == 0) { mask[i] = 0; }
578647
else { mask[i] = 1; }
579648
}
580-
''').substitute(a=1),
581-
name='polygon_mask_kernel')
649+
"""
650+
).substitute(a=1),
651+
name="polygon_mask_kernel",
652+
)
582653
return polygon_mask_kernel
583654

584655

585-
if __name__ == '__main__':
656+
if __name__ == "__main__":
586657
for i in range(10):
587658
import random
659+
588660
a = cp.zeros((100, 100))
589661
n = random.randint(3, 5)
590662

591663
# polygon = cp.array([[-1, -1], [3, 4], [2, 4], [1, 3]], dtype=float)
592-
polygon = cp.array([[(random.random() - 0.5) * 10, (random.random() - 0.5) * 10] for i in range(n)], dtype=float)
664+
polygon = cp.array(
665+
[[(random.random() - 0.5) * 10, (random.random() - 0.5) * 10] for i in range(n)], dtype=float
666+
)
593667
print(polygon)
594668
polygon_min = polygon.min(axis=0)
595669
polygon_max = polygon.max(axis=0)
@@ -599,10 +673,12 @@ def polygon_mask_kernel(width, height, resolution):
599673
# polygon_bbox = cp.array([-5, -5, 5, 5], dtype=float)
600674
polygon_mask = polygon_mask_kernel(100, 100, 0.1)
601675
import time
676+
602677
start = time.time()
603-
polygon_mask(polygon, 0.0, 0.0, polygon_n, polygon_bbox, a, size=(100*100))
678+
polygon_mask(polygon, 0.0, 0.0, polygon_n, polygon_bbox, a, size=(100 * 100))
604679
print(time.time() - start)
605680
import pylab as plt
681+
606682
print(a)
607683
plt.imshow(cp.asnumpy(a))
608684
plt.show()

0 commit comments

Comments
 (0)