Skip to content

Commit a01e65e

Browse files
committed
Complete and test G3Map methods
1 parent e05e5f7 commit a01e65e

File tree

3 files changed

+199
-0
lines changed

3 files changed

+199
-0
lines changed

core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ add_spt3g_program(bin/gen-analysis-doc)
9999
#Tests
100100
add_spt3g_test(imports)
101101
add_spt3g_test(copycons)
102+
add_spt3g_test(maps)
102103
add_spt3g_test(framepickle)
103104
add_spt3g_test(pipeline)
104105
add_spt3g_test(pipeline_module)

core/include/core/container_pybindings.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,27 @@ register_map(py::module_ &scope, std::string name, Args &&...args)
263263
return it->second;
264264
}, py::return_value_policy::reference_internal);
265265

266+
cls.def("copy", [](const M &m) {
267+
auto v = std::unique_ptr<M>(new M());
268+
for (auto it: m)
269+
v->emplace(it.first, it.second);
270+
return v.release();
271+
});
272+
273+
cls.def("get", [](const M &m, const K &k) {
274+
auto it = m.find(k);
275+
if (it == m.end())
276+
return py::object(py::none());
277+
return py::cast(it->second);
278+
});
279+
280+
cls.def("get", [](const M &m, const K &k, const py::object &d) {
281+
auto it = m.find(k);
282+
if (it == m.end())
283+
return d;
284+
return py::cast(it->second);
285+
});
286+
266287
cls.def("__contains__", [](M &m, const K &k) -> bool {
267288
auto it = m.find(k);
268289
if (it == m.end())
@@ -282,6 +303,38 @@ register_map(py::module_ &scope, std::string name, Args &&...args)
282303
m.erase(it);
283304
});
284305

306+
cls.def("pop", [](M &m, const K &k) {
307+
auto it = m.find(k);
308+
if (it == m.end())
309+
throw py::key_error();
310+
auto v = it->second;
311+
m.erase(it);
312+
return v;
313+
});
314+
315+
cls.def("pop", [](M &m, const K &k, const py::object &d) {
316+
auto it = m.find(k);
317+
if (it == m.end())
318+
return d;
319+
auto v = it->second;
320+
m.erase(it);
321+
return py::cast(v);
322+
});
323+
324+
cls.def("clear", [](M &m) { m.clear(); });
325+
326+
cls.def("update", [](M &m, const py::iterable &v, py::kwargs kw) {
327+
for (auto it: py::dict(v))
328+
m[it.first.cast<K>()] = it.second.cast<V>();
329+
for (auto it: kw)
330+
m[it.first.cast<K>()] = it.second.cast<V>();
331+
});
332+
333+
cls.def("update", [](M &m, py::kwargs kw) {
334+
for (auto it: kw)
335+
m[it.first.cast<K>()] = it.second.cast<V>();
336+
});
337+
285338
// Always use a lambda in case of `using` declaration
286339
cls.def("__len__", [](const M &m) { return m.size(); });
287340

core/tests/maps.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from spt3g.core import G3MapInt
2+
import unittest
3+
4+
class G3MapTestSuite(unittest.TestCase):
5+
def setUp(self):
6+
self.d = {"a": 1, "b": 2, "c": 3}
7+
self.m = G3MapInt(self.d)
8+
9+
def test_getitem(self):
10+
self.assertEqual(self.m["a"], 1)
11+
self.assertEqual(self.m["b"], 2)
12+
self.assertEqual(self.m["c"], 3)
13+
with self.assertRaises(KeyError):
14+
_ = self.m["d"]
15+
16+
def test_setitem(self):
17+
self.m["a"] = 10
18+
self.assertEqual(self.m["a"], 10)
19+
self.m["d"] = 4
20+
self.assertEqual(self.m["d"], 4)
21+
22+
def test_delitem(self):
23+
del self.m["b"]
24+
with self.assertRaises(KeyError):
25+
_ = self.m["b"]
26+
self.assertEqual(len(self.m), 2)
27+
28+
del self.m["a"]
29+
del self.m["c"]
30+
self.assertTrue(not self.m)
31+
32+
with self.assertRaises(KeyError):
33+
del self.m["a"]
34+
35+
def test_contains(self):
36+
self.assertTrue("a" in self.m)
37+
self.assertTrue("b" in self.m)
38+
self.assertTrue("c" in self.m)
39+
self.assertFalse("d" in self.m)
40+
41+
def test_len(self):
42+
self.assertEqual(len(self.m), 3)
43+
self.m["d"] = 4
44+
self.assertEqual(len(self.m), 4)
45+
del self.m["a"]
46+
self.assertEqual(len(self.m), 3)
47+
48+
def test_iter(self):
49+
keys = sorted(list(self.m.keys()))
50+
expected_keys = sorted(list(self.d.keys()))
51+
self.assertEqual(keys, expected_keys)
52+
53+
values = []
54+
for key in self.m:
55+
values.append(self.m[key])
56+
self.assertEqual(sorted(values), sorted(self.d.values()))
57+
58+
def test_keys(self):
59+
keys_view = self.m.keys()
60+
self.assertEqual(sorted(list(keys_view)), sorted(list(self.d.keys())))
61+
self.assertEqual(len(keys_view), len(self.d))
62+
self.m["d"] = 4
63+
self.assertEqual(sorted(list(keys_view)), sorted(list(self.d.keys()) + ["d"]))
64+
65+
def test_values(self):
66+
values_view = self.m.values()
67+
self.assertEqual(sorted(list(values_view)), sorted(list(self.d.values())))
68+
self.assertEqual(len(values_view), len(self.d))
69+
self.m["a"] = 10
70+
values = list(values_view)
71+
values.sort()
72+
expected_values = list(self.d.values())
73+
expected_values[0] = 10
74+
expected_values.sort()
75+
self.assertEqual(sorted(values), sorted(expected_values))
76+
77+
def test_items(self):
78+
items_view = self.m.items()
79+
self.assertEqual(sorted(list(items_view)), sorted(list(self.d.items())))
80+
self.assertEqual(len(items_view), len(self.d))
81+
self.m["d"] = 4
82+
self.assertEqual(sorted(list(items_view)), sorted(list(self.d.items()) + [("d", 4)]))
83+
84+
def test_clear(self):
85+
self.m.clear()
86+
self.assertEqual(len(self.m), 0)
87+
self.assertFalse(self.m)
88+
self.assertEqual(list(self.m.keys()), [])
89+
self.assertEqual(list(self.m.values()), [])
90+
self.assertEqual(list(self.m.items()), [])
91+
92+
def test_copy(self):
93+
copied_dict = self.m.copy()
94+
self.assertEqual(len(copied_dict), len(self.m))
95+
self.assertEqual(sorted(list(copied_dict.items())), sorted(list(self.m.items())))
96+
copied_dict["a"] = 100
97+
self.assertNotEqual(self.m["a"], 100)
98+
self.assertEqual(copied_dict["a"], 100)
99+
100+
def test_get(self):
101+
self.assertEqual(self.m.get("a"), 1)
102+
self.assertEqual(self.m.get("d"), None)
103+
self.assertEqual(self.m.get("d", 5), 5)
104+
self.assertEqual(self.m.get("b", 10), 2)
105+
106+
def test_pop(self):
107+
self.assertEqual(self.m.pop("b"), 2)
108+
self.assertEqual(len(self.m), 2)
109+
with self.assertRaises(KeyError):
110+
_ = self.m["b"]
111+
self.assertEqual(self.m.pop("d", 5), 5)
112+
with self.assertRaises(KeyError):
113+
self.m.pop("d")
114+
115+
def test_update(self):
116+
self.m.update({"b": 20, "d": 4})
117+
self.assertEqual(self.m["b"], 20)
118+
self.assertEqual(self.m["d"], 4)
119+
self.assertEqual(len(self.m), 4)
120+
121+
self.m.update([("c", 30), ("e", 5)])
122+
self.assertEqual(self.m["c"], 30)
123+
self.assertEqual(self.m["e"], 5)
124+
self.assertEqual(len(self.m), 5)
125+
126+
self.m.update(f=6)
127+
self.assertEqual(self.m["f"], 6)
128+
self.assertEqual(len(self.m), 6)
129+
130+
self.m.update({})
131+
self.assertEqual(len(self.m), 6)
132+
133+
self.m.update([])
134+
self.assertEqual(len(self.m), 6)
135+
136+
self.m.update()
137+
self.assertEqual(len(self.m), 6)
138+
139+
def test_bool(self):
140+
self.assertTrue(self.m)
141+
self.m.clear()
142+
self.assertFalse(self.m)
143+
144+
if __name__ == "__main__":
145+
unittest.main()

0 commit comments

Comments
 (0)