99import inspect
1010from dataclasses import dataclass
1111from ruamel .yaml import YAML
12+ from inspect import signature
1213
1314
1415@dataclass
@@ -38,6 +39,7 @@ def __call__(
3839 layer_names : List [str ],
3940 plugin_layers : cp .ndarray ,
4041 plugin_layer_names : List [str ],
42+ ** kwargs ,
4143 ) -> cp .ndarray :
4244 """This gets the elevation map data and plugin layers as a cupy array.
4345
@@ -75,9 +77,15 @@ def init(self, plugin_params: List[PluginParams], extra_params: List[Dict]):
7577
7678 self .plugins = []
7779 for param , extra_param in zip (plugin_params , extra_params ):
78- m = importlib .import_module ("." + param .name , package = "elevation_mapping_cupy.plugins" ) # -> 'module'
80+ m = importlib .import_module (
81+ "." + param .name , package = "elevation_mapping_cupy.plugins"
82+ ) # -> 'module'
7983 for name , obj in inspect .getmembers (m ):
80- if inspect .isclass (obj ) and issubclass (obj , PluginBase ) and name != "PluginBase" :
84+ if (
85+ inspect .isclass (obj )
86+ # and issubclass(obj, PluginBase)
87+ and name != "PluginBase"
88+ ):
8189 # Add cell_n to params
8290 extra_param ["cell_n" ] = self .cell_n
8391 self .plugins .append (obj (** extra_param ))
@@ -102,6 +110,7 @@ def load_plugin_settings(self, file_path: str):
102110 )
103111 )
104112 extra_params .append (v ["extra_params" ])
113+ print (plugin_params )
105114 self .init (plugin_params , extra_params )
106115 print ("Loaded plugins are " , * self .plugin_names )
107116
@@ -133,10 +142,38 @@ def get_layer_index_with_name(self, name: str) -> int:
133142 print ("Error with layer {}: {}" .format (name , e ))
134143 return None
135144
136- def update_with_name (self , name : str , elevation_map : cp .ndarray , layer_names : List [str ]):
145+ def update_with_name (
146+ self ,
147+ name : str ,
148+ elevation_map : cp .ndarray ,
149+ layer_names : List [str ],
150+ semantic_map = None ,
151+ transform = None ,
152+ ):
137153 idx = self .get_layer_index_with_name (name )
138154 if idx is not None :
139- self .layers [idx ] = self .plugins [idx ](elevation_map , layer_names , self .layers , self .layer_names )
155+ n_param = len (signature (self .plugins [idx ]).parameters )
156+ if n_param == 5 :
157+ self .layers [idx ] = self .plugins [idx ](
158+ elevation_map , layer_names , self .layers , self .layer_names
159+ )
160+ elif n_param == 6 :
161+ self .layers [idx ] = self .plugins [idx ](
162+ elevation_map ,
163+ layer_names ,
164+ self .layers ,
165+ self .layer_names ,
166+ semantic_map ,
167+ )
168+ else :
169+ self .layers [idx ] = self .plugins [idx ](
170+ elevation_map ,
171+ layer_names ,
172+ self .layers ,
173+ self .layer_names ,
174+ semantic_map ,
175+ transform ,
176+ )
140177
141178 def get_map_with_name (self , name : str ) -> cp .ndarray :
142179 idx = self .get_layer_index_with_name (name )
@@ -154,17 +191,28 @@ def get_param_with_name(self, name: str) -> PluginParams:
154191 PluginParams (name = "min_filter" , layer_name = "min_filter" ),
155192 PluginParams (name = "smooth_filter" , layer_name = "smooth" ),
156193 ]
157- extra_params = [{"dilation_size" : 5 , "iteration_n" : 5 }, {"input_layer_name" : "elevation2" }]
194+ extra_params = [
195+ {"dilation_size" : 5 , "iteration_n" : 5 },
196+ {"input_layer_name" : "elevation2" },
197+ ]
158198 manager = PluginManager (200 )
159- manager .load_plugin_settings ("config/plugin_config.yaml" )
199+ manager .load_plugin_settings ("../ config/plugin_config.yaml" )
160200 print (manager .layer_names )
161201 print (manager .plugin_names )
162202 elevation_map = cp .zeros ((7 , 200 , 200 )).astype (cp .float32 )
163- layer_names = ["elevation" , "variance" , "is_valid" , "traversability" , "time" , "upper_bound" , "is_upper_bound" ]
203+ layer_names = [
204+ "elevation" ,
205+ "variance" ,
206+ "is_valid" ,
207+ "traversability" ,
208+ "time" ,
209+ "upper_bound" ,
210+ "is_upper_bound" ,
211+ ]
164212 elevation_map [0 ] = cp .random .randn (200 , 200 )
165213 elevation_map [2 ] = cp .abs (cp .random .randn (200 , 200 ))
166214 print ("map" , elevation_map [0 ])
167215 print ("layer map " , manager .layers [0 ])
168216 manager .update_with_name ("min_filter" , elevation_map , layer_names )
169- manager .update_with_name ("smooth_filter " , elevation_map , layer_names )
217+ manager .update_with_name ("smooth " , elevation_map , layer_names )
170218 print (manager .get_map_with_name ("smooth" ))
0 commit comments