Skip to content

Commit c1637f7

Browse files
authored
Merge pull request #134 from leggedrobotics/feature/masked-replace-enhancements
Add masked_replace enhancements from dev/ros2/update_map_rebase
2 parents 95b7c2c + f775d8c commit c1637f7

File tree

2 files changed

+155
-24
lines changed

2 files changed

+155
-24
lines changed

elevation_mapping_cupy/elevation_mapping_cupy/elevation_mapping.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,12 +1099,13 @@ def apply_masked_replace(
10991099
if np.any(valid_mask):
11001100
vals = incoming_slice[valid_mask]
11011101
min_max = (float(np.nanmin(vals)), float(np.nanmax(vals)))
1102-
map_extent = self._map_extent_from_slices(map_rows, map_cols)
1102+
map_extent = self._map_extent_from_mask(map_rows, map_cols, valid_mask) or self._map_extent_from_slices(map_rows, map_cols)
11031103
print(
11041104
f"[ElevationMap] masked_replace layer '{name}': wrote {written} cells, "
11051105
f"X∈[{map_extent['x_min']:.2f},{map_extent['x_max']:.2f}], "
11061106
f"Y∈[{map_extent['y_min']:.2f},{map_extent['y_max']:.2f}], "
1107-
f"values {min_max if min_max else 'n/a'}"
1107+
f"values {min_max if min_max else 'n/a'}",
1108+
flush=True
11081109
)
11091110

11101111
self._invalidate_caches()
@@ -1252,6 +1253,27 @@ def _map_extent_from_slices(self, rows: slice, cols: slice) -> Dict[str, float]:
12521253
y_max = map_min_y + (rows.stop - 0.5) * self.resolution
12531254
return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
12541255

1256+
def _map_extent_from_mask(self, rows: slice, cols: slice, valid_mask: np.ndarray) -> Optional[Dict[str, float]]:
1257+
"""Compute extent based on the actual mask footprint; returns None if mask is empty."""
1258+
if valid_mask is None or not np.any(valid_mask):
1259+
return None
1260+
row_idx, col_idx = np.nonzero(valid_mask)
1261+
row_min = rows.start + int(row_idx.min())
1262+
row_max = rows.start + int(row_idx.max())
1263+
col_min = cols.start + int(col_idx.min())
1264+
col_max = cols.start + int(col_idx.max())
1265+
1266+
map_length = (self.cell_n - 2) * self.resolution
1267+
center_cpu = np.asarray(cp.asnumpy(self.center))
1268+
map_min_x = center_cpu[0] - map_length / 2.0
1269+
map_min_y = center_cpu[1] - map_length / 2.0
1270+
1271+
x_min = map_min_x + (col_min + 0.5) * self.resolution
1272+
x_max = map_min_x + (col_max + 0.5) * self.resolution
1273+
y_min = map_min_y + (row_min + 0.5) * self.resolution
1274+
y_max = map_min_y + (row_max + 0.5) * self.resolution
1275+
return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
1276+
12551277
def _invalidate_caches(self, reset_plugins: bool = True):
12561278
self.traversability_buffer[...] = cp.nan
12571279
if reset_plugins:

scripts/masked_replace_tool.py

Lines changed: 131 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,30 @@ def build_parser() -> argparse.ArgumentParser:
4646
parser.add_argument("--center-z", type=float, default=0.0, help="Patch center Z coordinate (meters).")
4747
parser.add_argument("--size-x", type=positive_float, default=1.0, help="Patch length in X (meters).")
4848
parser.add_argument("--size-y", type=positive_float, default=1.0, help="Patch length in Y (meters).")
49+
parser.add_argument(
50+
"--full-length-x",
51+
type=positive_float,
52+
default=None,
53+
help="Optional total GridMap length in X (meters). If set, a full-size map is sent and only the patch region is marked in the mask."
54+
)
55+
parser.add_argument(
56+
"--full-length-y",
57+
type=positive_float,
58+
default=None,
59+
help="Optional total GridMap length in Y (meters). If set, a full-size map is sent and only the patch region is marked in the mask."
60+
)
61+
parser.add_argument(
62+
"--full-center-x",
63+
type=float,
64+
default=0.0,
65+
help="GridMap center X (meters) to use when sending a full-size map. Defaults to 0."
66+
)
67+
parser.add_argument(
68+
"--full-center-y",
69+
type=float,
70+
default=0.0,
71+
help="GridMap center Y (meters) to use when sending a full-size map. Defaults to 0."
72+
)
4973
parser.add_argument("--resolution", type=positive_float, default=0.1, help="Grid resolution (meters per cell).")
5074
parser.add_argument("--elevation", type=float, default=0.1, help="Elevation value to set (meters).")
5175
parser.add_argument("--variance", type=non_negative_float, default=0.05, help="Variance value to set.")
@@ -84,6 +108,10 @@ class PatchConfig:
84108
mask_value: float
85109
add_valid_layer: bool
86110
invalidate_first: bool
111+
full_length_x: Optional[float] = None
112+
full_length_y: Optional[float] = None
113+
full_center_x: float = 0.0
114+
full_center_y: float = 0.0
87115

88116
@property
89117
def shape(self) -> Dict[str, int]:
@@ -139,10 +167,17 @@ def _base_grid_map(self) -> GridMap:
139167
gm.header.frame_id = cfg.frame_id
140168
gm.header.stamp = self.get_clock().now().to_msg()
141169
gm.info.resolution = cfg.resolution
142-
gm.info.length_x = cfg.actual_length_x
143-
gm.info.length_y = cfg.actual_length_y
144-
gm.info.pose.position.x = cfg.center_x
145-
gm.info.pose.position.y = cfg.center_y
170+
# If full map was requested, use the full lengths and center the GridMap at the full-map center.
171+
if cfg.full_length_x or cfg.full_length_y:
172+
gm.info.length_x = cfg.full_length_x or cfg.actual_length_x
173+
gm.info.length_y = cfg.full_length_y or cfg.actual_length_y
174+
gm.info.pose.position.x = cfg.full_center_x
175+
gm.info.pose.position.y = cfg.full_center_y
176+
else:
177+
gm.info.length_x = cfg.actual_length_x
178+
gm.info.length_y = cfg.actual_length_y
179+
gm.info.pose.position.x = cfg.center_x
180+
gm.info.pose.position.y = cfg.center_y
146181
gm.info.pose.position.z = cfg.center_z
147182
gm.info.pose.orientation.w = 1.0
148183
gm.basic_layers = ["elevation"]
@@ -157,32 +192,102 @@ def _mask_array(self, force_value: Optional[float] = None) -> np.ndarray:
157192
mask_value = 1.0
158193
return np.full((rows, cols), mask_value, dtype=np.float32)
159194

195+
def _make_full_arrays(self) -> Dict[str, np.ndarray]:
196+
"""Create full-size arrays (possibly larger than the patch) and place the patch in them."""
197+
cfg = self._config
198+
length_x = cfg.full_length_x or cfg.length_x
199+
length_y = cfg.full_length_y or cfg.length_y
200+
cols_full = max(1, ceil(length_x / cfg.resolution))
201+
rows_full = max(1, ceil(length_y / cfg.resolution))
202+
203+
# Base arrays filled with NaN (masked out)
204+
mask_full = np.full((rows_full, cols_full), np.nan, dtype=np.float32)
205+
elev_full = np.full((rows_full, cols_full), np.nan, dtype=np.float32)
206+
var_full = np.full((rows_full, cols_full), np.nan, dtype=np.float32)
207+
valid_full = np.zeros((rows_full, cols_full), dtype=np.float32)
208+
209+
# Patch dimensions and offset within the full map
210+
patch_rows = cfg.shape["rows"]
211+
patch_cols = cfg.shape["cols"]
212+
row_offset = int(round(cfg.center_y / cfg.resolution))
213+
col_offset = int(round(cfg.center_x / cfg.resolution))
214+
row_start = rows_full // 2 + row_offset - patch_rows // 2
215+
col_start = cols_full // 2 + col_offset - patch_cols // 2
216+
row_end = row_start + patch_rows
217+
col_end = col_start + patch_cols
218+
219+
# Clamp if window would exceed bounds
220+
if row_start < 0 or col_start < 0 or row_end > rows_full or col_end > cols_full:
221+
raise ValueError("Patch exceeds full map bounds; adjust center/size or full map length.")
222+
223+
mask_val = cfg.mask_value
224+
if np.isnan(mask_val):
225+
mask_val = 1.0
226+
mask_full[row_start:row_end, col_start:col_end] = mask_val
227+
elev_full[row_start:row_end, col_start:col_end] = cfg.elevation
228+
var_full[row_start:row_end, col_start:col_end] = cfg.variance
229+
if cfg.add_valid_layer:
230+
valid_full[row_start:row_end, col_start:col_end] = 1.0
231+
232+
return {
233+
"mask": mask_full,
234+
"elevation": elev_full,
235+
"variance": var_full,
236+
"is_valid": valid_full,
237+
"rows_full": rows_full,
238+
"cols_full": cols_full,
239+
}
240+
160241
def _build_validity_message(self, value: float) -> GridMap:
161242
gm = self._base_grid_map()
162-
mask = self._mask_array()
163-
rows, cols = mask.shape
164-
gm.layers = [self._config.mask_layer, "is_valid"]
165-
arrays = {
166-
self._config.mask_layer: mask,
167-
"is_valid": np.full((rows, cols), value, dtype=np.float32),
168-
}
243+
cfg = self._config
244+
if cfg.full_length_x or cfg.full_length_y:
245+
arrays_full = self._make_full_arrays()
246+
rows_full = arrays_full["rows_full"]
247+
cols_full = arrays_full["cols_full"]
248+
gm.layers = [cfg.mask_layer, "is_valid"]
249+
arrays = {
250+
cfg.mask_layer: arrays_full["mask"],
251+
"is_valid": np.full((rows_full, cols_full), value, dtype=np.float32),
252+
}
253+
else:
254+
mask = self._mask_array()
255+
rows, cols = mask.shape
256+
gm.layers = [cfg.mask_layer, "is_valid"]
257+
arrays = {
258+
cfg.mask_layer: mask,
259+
"is_valid": np.full((rows, cols), value, dtype=np.float32),
260+
}
169261
for layer in gm.layers:
170262
gm.data.append(self._numpy_to_multiarray(arrays[layer]))
171263
return gm
172264

173265
def _build_data_message(self, valid_value: Optional[float]) -> GridMap:
174266
gm = self._base_grid_map()
175-
mask = self._mask_array()
176-
rows, cols = mask.shape
177-
gm.layers = [self._config.mask_layer, "elevation", "variance"]
178-
arrays = {
179-
self._config.mask_layer: mask,
180-
"elevation": np.full((rows, cols), self._config.elevation, dtype=np.float32),
181-
"variance": np.full((rows, cols), self._config.variance, dtype=np.float32),
182-
}
183-
if valid_value is not None:
184-
gm.layers.append("is_valid")
185-
arrays["is_valid"] = np.full((rows, cols), valid_value, dtype=np.float32)
267+
cfg = self._config
268+
if cfg.full_length_x or cfg.full_length_y:
269+
arrays_full = self._make_full_arrays()
270+
gm.layers = [cfg.mask_layer, "elevation", "variance"]
271+
arrays = {
272+
cfg.mask_layer: arrays_full["mask"],
273+
"elevation": arrays_full["elevation"],
274+
"variance": arrays_full["variance"],
275+
}
276+
if valid_value is not None:
277+
gm.layers.append("is_valid")
278+
arrays["is_valid"] = arrays_full["is_valid"]
279+
else:
280+
mask = self._mask_array()
281+
rows, cols = mask.shape
282+
gm.layers = [cfg.mask_layer, "elevation", "variance"]
283+
arrays = {
284+
cfg.mask_layer: mask,
285+
"elevation": np.full((rows, cols), cfg.elevation, dtype=np.float32),
286+
"variance": np.full((rows, cols), cfg.variance, dtype=np.float32),
287+
}
288+
if valid_value is not None:
289+
gm.layers.append("is_valid")
290+
arrays["is_valid"] = np.full((rows, cols), valid_value, dtype=np.float32)
186291
for layer in gm.layers:
187292
gm.data.append(self._numpy_to_multiarray(arrays[layer]))
188293
return gm
@@ -216,6 +321,10 @@ def main() -> None:
216321
mask_value=args.mask_value,
217322
add_valid_layer=args.valid_layer,
218323
invalidate_first=args.invalidate_first,
324+
full_length_x=args.full_length_x,
325+
full_length_y=args.full_length_y,
326+
full_center_x=args.full_center_x,
327+
full_center_y=args.full_center_y,
219328
)
220329

221330
rclpy.init()

0 commit comments

Comments
 (0)