|
| 1 | +import sys |
| 2 | +from pathlib import Path |
| 3 | +from types import ModuleType, SimpleNamespace |
| 4 | +from unittest import mock |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import pytest |
| 8 | + |
| 9 | +sys.modules.setdefault("cupy", mock.MagicMock()) |
| 10 | + |
| 11 | +if "grid_map_msgs" not in sys.modules: |
| 12 | + grid_map_module = ModuleType("grid_map_msgs") |
| 13 | + grid_map_msg_module = ModuleType("grid_map_msgs.msg") |
| 14 | + grid_map_srv_module = ModuleType("grid_map_msgs.srv") |
| 15 | + |
| 16 | + class _DummyGridMap: |
| 17 | + pass |
| 18 | + |
| 19 | + class _DummySetGridMap: |
| 20 | + pass |
| 21 | + |
| 22 | + class _DummyProcessFile: |
| 23 | + pass |
| 24 | + |
| 25 | + grid_map_msg_module.GridMap = _DummyGridMap |
| 26 | + grid_map_srv_module.SetGridMap = _DummySetGridMap |
| 27 | + grid_map_srv_module.ProcessFile = _DummyProcessFile |
| 28 | + |
| 29 | + grid_map_module.msg = grid_map_msg_module |
| 30 | + grid_map_module.srv = grid_map_srv_module |
| 31 | + |
| 32 | + sys.modules["grid_map_msgs"] = grid_map_module |
| 33 | + sys.modules["grid_map_msgs.msg"] = grid_map_msg_module |
| 34 | + sys.modules["grid_map_msgs.srv"] = grid_map_srv_module |
| 35 | + |
| 36 | +from grid_map_msgs.msg import GridMap |
| 37 | + |
| 38 | +REPO_ROOT = Path(__file__).resolve().parents[1] |
| 39 | +sys.path.insert(0, str(REPO_ROOT / "elevation_mapping_cupy")) |
| 40 | + |
| 41 | +from elevation_mapping_cupy import elevation_mapping_node as node_mod |
| 42 | + |
| 43 | + |
| 44 | +class DummyLogger: |
| 45 | + def __init__(self): |
| 46 | + self.messages = [] |
| 47 | + |
| 48 | + def info(self, msg): |
| 49 | + self.messages.append(("info", msg)) |
| 50 | + |
| 51 | + def error(self, msg): |
| 52 | + self.messages.append(("error", msg)) |
| 53 | + |
| 54 | + |
| 55 | +class DummyClock: |
| 56 | + def now(self): |
| 57 | + return SimpleNamespace(nanoseconds=123) |
| 58 | + |
| 59 | + |
| 60 | +def bind_method(method, instance): |
| 61 | + """Bind a class method to a lightweight instance for unit testing.""" |
| 62 | + return method.__get__(instance, node_mod.ElevationMappingNode) |
| 63 | + |
| 64 | + |
| 65 | +def make_dummy_node_for_masked_replace(): |
| 66 | + dummy = SimpleNamespace() |
| 67 | + dummy.masked_replace_mask_layer_name = "mask" |
| 68 | + dummy.logger = DummyLogger() |
| 69 | + dummy.get_logger = lambda: dummy.logger |
| 70 | + dummy._map = SimpleNamespace() |
| 71 | + dummy._map.apply_masked_replace = mock.Mock() |
| 72 | + dummy._republish_all_once = mock.Mock() |
| 73 | + return dummy |
| 74 | + |
| 75 | + |
| 76 | +def test_handle_masked_replace_sets_success_flag(): |
| 77 | + dummy = make_dummy_node_for_masked_replace() |
| 78 | + |
| 79 | + size = 2 |
| 80 | + data_layers = { |
| 81 | + "mask": np.ones((size, size), dtype=np.float32), |
| 82 | + "elevation": np.full((size, size), 1.0, dtype=np.float32), |
| 83 | + } |
| 84 | + mask_layer = data_layers["mask"] |
| 85 | + geometry = SimpleNamespace() |
| 86 | + dummy._grid_map_to_numpy = mock.Mock(return_value=(data_layers, geometry)) |
| 87 | + |
| 88 | + request = SimpleNamespace(map=GridMap()) |
| 89 | + response = SimpleNamespace(success=False) |
| 90 | + |
| 91 | + handler = bind_method(node_mod.ElevationMappingNode.handle_masked_replace, dummy) |
| 92 | + handler(request, response) |
| 93 | + |
| 94 | + assert response.success is True |
| 95 | + dummy._map.apply_masked_replace.assert_called_once() |
| 96 | + call_args = dummy._map.apply_masked_replace.call_args |
| 97 | + layers_arg, mask_arg, geometry_arg = call_args.args |
| 98 | + assert "mask" not in layers_arg |
| 99 | + assert mask_arg is mask_layer |
| 100 | + assert geometry_arg is geometry |
| 101 | + dummy._republish_all_once.assert_called_once() |
| 102 | + |
| 103 | + |
| 104 | +def test_handle_masked_replace_failure_sets_success_false(): |
| 105 | + dummy = make_dummy_node_for_masked_replace() |
| 106 | + dummy._grid_map_to_numpy = mock.Mock(side_effect=RuntimeError("boom")) |
| 107 | + |
| 108 | + request = SimpleNamespace(map=GridMap()) |
| 109 | + response = SimpleNamespace(success=True) |
| 110 | + |
| 111 | + handler = bind_method(node_mod.ElevationMappingNode.handle_masked_replace, dummy) |
| 112 | + handler(request, response) |
| 113 | + |
| 114 | + assert response.success is False |
| 115 | + |
| 116 | + |
| 117 | +def test_write_grid_map_bag_uses_correct_topic_metadata(tmp_path, monkeypatch): |
| 118 | + dummy = SimpleNamespace() |
| 119 | + dummy.save_map_storage_id = "mcap" |
| 120 | + dummy._clock = DummyClock() |
| 121 | + dummy.get_clock = lambda: dummy._clock |
| 122 | + |
| 123 | + storage_mock = mock.Mock(name="StorageOptions") |
| 124 | + converter_mock = mock.Mock(name="ConverterOptions") |
| 125 | + metadata_mock = mock.Mock(name="TopicMetadataReturn") |
| 126 | + writer_mock = mock.Mock(name="SequentialWriter") |
| 127 | + |
| 128 | + storage_cls = mock.Mock(return_value=storage_mock) |
| 129 | + converter_cls = mock.Mock(return_value=converter_mock) |
| 130 | + call_sequence = {"count": 0} |
| 131 | + |
| 132 | + def topic_metadata_side_effect(*args, **kwargs): |
| 133 | + call_sequence["count"] += 1 |
| 134 | + if call_sequence["count"] == 1: |
| 135 | + assert args == (topic, "grid_map_msgs/msg/GridMap", "cdr", "") |
| 136 | + raise TypeError("legacy signature only accepts name first") |
| 137 | + assert args == (0, topic, "grid_map_msgs/msg/GridMap", "cdr") |
| 138 | + return metadata_mock |
| 139 | + |
| 140 | + metadata_cls = mock.Mock(side_effect=topic_metadata_side_effect) |
| 141 | + writer_cls = mock.Mock(return_value=writer_mock) |
| 142 | + |
| 143 | + monkeypatch.setattr(node_mod.rosbag2_py, "StorageOptions", storage_cls) |
| 144 | + monkeypatch.setattr(node_mod.rosbag2_py, "ConverterOptions", converter_cls) |
| 145 | + monkeypatch.setattr(node_mod.rosbag2_py, "TopicMetadata", metadata_cls) |
| 146 | + monkeypatch.setattr(node_mod.rosbag2_py, "SequentialWriter", writer_cls) |
| 147 | + monkeypatch.setattr(node_mod, "serialize_message", lambda msg: b"serialized") |
| 148 | + |
| 149 | + path = tmp_path / "map_bag" |
| 150 | + topic = "/test_topic" |
| 151 | + grid_map_msg = GridMap() |
| 152 | + |
| 153 | + writer = bind_method(node_mod.ElevationMappingNode._write_grid_map_bag, dummy) |
| 154 | + dummy._make_topic_metadata = bind_method(node_mod.ElevationMappingNode._make_topic_metadata, dummy) |
| 155 | + writer(path, topic, grid_map_msg) |
| 156 | + |
| 157 | + storage_cls.assert_called_once_with(uri=str(path), storage_id="mcap") |
| 158 | + converter_cls.assert_called_once_with("", "") |
| 159 | + assert metadata_cls.call_count == 2 |
| 160 | + assert metadata_cls.call_args_list[-1].args == (0, topic, "grid_map_msgs/msg/GridMap", "cdr") |
| 161 | + writer_mock.create_topic.assert_called_once_with(metadata_mock) |
| 162 | + writer_mock.write.assert_called_once() |
0 commit comments