1+ #
2+ # Copyright (c) 2022, Takahiro Miki. All rights reserved.
3+ # Licensed under the MIT license. See LICENSE file in the project root for details.
4+ #
5+ import cupy as cp
6+ import string
7+ from typing import List
8+
9+ from .plugin_manager import PluginBase
10+
11+
12+ class MaxFilter (PluginBase ):
13+ """This is a filter to fill in invalid cells with maximum values around.
14+
15+ ...
16+
17+ Attributes
18+ ----------
19+ width: int
20+ width of the elevation map.
21+ height: int
22+ height of the elevation map.
23+ dilation_size: int
24+ The size of the patch to search for maximum value for each iteration.
25+ iteration_n: int
26+ The number of iteration to repeat the same filter.
27+ """
28+
29+ def __init__ (self , cell_n : int = 100 , dilation_size : int = 5 , iteration_n : int = 5 , ** kwargs ):
30+ super ().__init__ ()
31+ self .iteration_n = iteration_n
32+ self .width = cell_n
33+ self .height = cell_n
34+ self .max_filtered = cp .zeros ((self .width , self .height ))
35+ self .max_filtered_mask = cp .zeros ((self .width , self .height ))
36+ self .max_filter_kernel = cp .ElementwiseKernel (
37+ in_params = "raw U map, raw U mask" ,
38+ out_params = "raw U newmap, raw U newmask" ,
39+ preamble = string .Template (
40+ """
41+ __device__ int get_map_idx(int idx, int layer_n) {
42+ const int layer = ${width} * ${height};
43+ return layer * layer_n + idx;
44+ }
45+
46+ __device__ int get_relative_map_idx(int idx, int dx, int dy, int layer_n) {
47+ const int layer = ${width} * ${height};
48+ const int relative_idx = idx + ${width} * dy + dx;
49+ return layer * layer_n + relative_idx;
50+ }
51+ __device__ bool is_inside(int idx) {
52+ int idx_x = idx / ${width};
53+ int idx_y = idx % ${width};
54+ if (idx_x <= 0 || idx_x >= ${width} - 1) {
55+ return false;
56+ }
57+ if (idx_y <= 0 || idx_y >= ${height} - 1) {
58+ return false;
59+ }
60+ return true;
61+ }
62+ """
63+ ).substitute (width = self .width , height = self .height ),
64+ operation = string .Template (
65+ """
66+ U valid = mask[get_map_idx(i, 0)];
67+ if (valid < 0.5) {
68+ U max_value = -1000000.0;
69+ for (int dy = -${dilation_size}; dy <= ${dilation_size}; dy++) {
70+ for (int dx = -${dilation_size}; dx <= ${dilation_size}; dx++) {
71+ int idx = get_relative_map_idx(i, dx, dy, 0);
72+ if (!is_inside(idx)) {continue;}
73+ U valid = mask[idx];
74+ U value = map[idx];
75+ if(valid > 0.5 && value > max_value) {
76+ max_value = value;
77+ }
78+ }
79+ }
80+ if (max_value > -1000000 + 1) {
81+ newmap[get_map_idx(i, 0)] = max_value;
82+ newmask[get_map_idx(i, 0)] = 0.6;
83+ }
84+ }
85+ """
86+ ).substitute (dilation_size = dilation_size ),
87+ name = "max_filter_kernel" ,
88+ )
89+
90+ def __call__ (
91+ self ,
92+ elevation_map : cp .ndarray ,
93+ layer_names : List [str ],
94+ plugin_layers : cp .ndarray ,
95+ plugin_layer_names : List [str ],
96+ ) -> cp .ndarray :
97+ self .max_filtered = elevation_map [0 ].copy ()
98+ self .max_filtered_mask = elevation_map [2 ].copy ()
99+ for i in range (self .iteration_n ):
100+ self .max_filter_kernel (
101+ self .max_filtered .copy (),
102+ self .max_filtered_mask .copy (),
103+ self .max_filtered ,
104+ self .max_filtered_mask ,
105+ size = (self .width * self .height ),
106+ )
107+ # If there's no more mask, break
108+ if (self .max_filtered_mask > 0.5 ).all ():
109+ break
110+ max_filtered = cp .where (self .max_filtered_mask > 0.5 , self .max_filtered .copy (), cp .nan )
111+ return max_filtered
0 commit comments