diff --git a/src/amulet/test_utils/test_utils.hpp b/src/amulet/test_utils/test_utils.hpp index 61b5c56..68ed922 100644 --- a/src/amulet/test_utils/test_utils.hpp +++ b/src/amulet/test_utils/test_utils.hpp @@ -14,7 +14,7 @@ std::string cast_to_string(const T& obj) } } -#define _ASSERT_COMPARE(CLS, A, B, OP) \ +#define _ASSERT_COMPARE_2(CLS, A, B, OP_FUNC, OP) \ { \ CLS assert_comp_a = [&]() { \ try { \ @@ -48,10 +48,10 @@ std::string cast_to_string(const T& obj) throw std::runtime_error(assert_comp_msg); \ } \ }(); \ - if (!(assert_comp_a OP assert_comp_b)) { \ + if (!(OP_FUNC)) { \ std::string assert_comp_msg; \ assert_comp_msg.reserve(200); \ - assert_comp_msg += "A " #OP " B failed in file "; \ + assert_comp_msg += "A " OP " B failed in file "; \ assert_comp_msg += __FILE__; \ assert_comp_msg += " at line "; \ assert_comp_msg += std::to_string(__LINE__); \ @@ -68,6 +68,8 @@ std::string cast_to_string(const T& obj) } \ } +#define _ASSERT_COMPARE(CLS, A, B, OP) _ASSERT_COMPARE_2(CLS, A, B, assert_comp_a OP assert_comp_b, #OP) + #define ASSERT_EQUAL(CLS, A, B) _ASSERT_COMPARE(CLS, A, B, ==) #define ASSERT_NOT_EQUAL(CLS, A, B) _ASSERT_COMPARE(CLS, A, B, !=) #define ASSERT_LESS(CLS, A, B) _ASSERT_COMPARE(CLS, A, B, <) @@ -103,3 +105,8 @@ std::string cast_to_string(const T& obj) throw std::runtime_error(assert_raise_msg); \ } \ } + +#define ASSERT_ALMOST_EQUAL_2(CLS, A, B, ERR) _ASSERT_COMPARE_2(CLS, A, B, std::abs(assert_comp_a - assert_comp_b) <= ERR, "≈") +#define ASSERT_ALMOST_EQUAL(CLS, A, B) ASSERT_ALMOST_EQUAL_2(CLS, A, B, 0.000001) +#define ASSERT_NOT_ALMOST_EQUAL_2(CLS, A, B, ERR) _ASSERT_COMPARE_2(CLS, A, B, std::abs(assert_comp_a - assert_comp_b) > ERR, "≉") +#define ASSERT_NOT_ALMOST_EQUAL(CLS, A, B) ASSERT_NOT_ALMOST_EQUAL_2(CLS, A, B, 0.000001) diff --git a/tests/test_amulet_test_utils/_test_test_utils.py.cpp b/tests/test_amulet_test_utils/_test_test_utils.py.cpp index 8e1df44..fd39924 100644 --- a/tests/test_amulet_test_utils/_test_test_utils.py.cpp +++ b/tests/test_amulet_test_utils/_test_test_utils.py.cpp @@ -73,6 +73,22 @@ static void test_assert_raises_3() { ASSERT_RAISES(std::runtime_error, throw std::invalid_argument("")) } +static void test_assert_almost_equal(double a, double b){ + ASSERT_ALMOST_EQUAL(double, a, b); +} + +static void test_assert_almost_equal_2(double a, double b, double err){ + ASSERT_ALMOST_EQUAL_2(double, a, b, err); +} + +static void test_assert_not_almost_equal(double a, double b){ + ASSERT_NOT_ALMOST_EQUAL(double, a, b); +} + +static void test_assert_not_almost_equal_2(double a, double b, double err){ + ASSERT_NOT_ALMOST_EQUAL_2(double, a, b, err); +} + PYBIND11_MODULE(_test_test_utils, m) { m.def("test_assert_equal_1", &test_assert_equal_1); @@ -89,4 +105,8 @@ PYBIND11_MODULE(_test_test_utils, m) m.def("test_assert_raises_1", &test_assert_raises_1); m.def("test_assert_raises_2", &test_assert_raises_2); m.def("test_assert_raises_3", &test_assert_raises_3); + m.def("test_assert_almost_equal", &test_assert_almost_equal); + m.def("test_assert_almost_equal_2", &test_assert_almost_equal_2); + m.def("test_assert_not_almost_equal", &test_assert_not_almost_equal); + m.def("test_assert_not_almost_equal_2", &test_assert_not_almost_equal_2); } diff --git a/tests/test_amulet_test_utils/_test_test_utils.pyi b/tests/test_amulet_test_utils/_test_test_utils.pyi index 9afc1ed..f056ce1 100644 --- a/tests/test_amulet_test_utils/_test_test_utils.pyi +++ b/tests/test_amulet_test_utils/_test_test_utils.pyi @@ -12,3 +12,7 @@ def test_assert_greater_equal(a: int, b: int) -> None: ... def test_assert_raises_1() -> None: ... def test_assert_raises_2() -> None: ... def test_assert_raises_3() -> None: ... +def test_assert_almost_equal(a: float, b: float) -> None: ... +def test_assert_almost_equal_2(a: float, b: float, err: float) -> None: ... +def test_assert_not_almost_equal(a: float, b: float) -> None: ... +def test_assert_not_almost_equal_2(a: float, b: float, err: float) -> None: ... diff --git a/tests/test_amulet_test_utils/test_test_utils.py b/tests/test_amulet_test_utils/test_test_utils.py index 0a7eb5f..214bcc5 100644 --- a/tests/test_amulet_test_utils/test_test_utils.py +++ b/tests/test_amulet_test_utils/test_test_utils.py @@ -68,3 +68,38 @@ def test_assert_raises(self) -> None: test_assert_raises_2() with self.assertRaises(RuntimeError): test_assert_raises_3() + + def test_almost_equal(self) -> None: + from test_amulet_test_utils._test_test_utils import ( + test_assert_almost_equal, + test_assert_almost_equal_2, + test_assert_not_almost_equal, + test_assert_not_almost_equal_2 + ) + test_assert_almost_equal(5.5, 5.50000001) + test_assert_almost_equal(5.50000001, 5.5) + with self.assertRaises(RuntimeError): + test_assert_almost_equal(5.5, 5.50001) + with self.assertRaises(RuntimeError): + test_assert_almost_equal(5.50001, 5.5) + + test_assert_almost_equal_2(5.5, 5.501, 0.01) + test_assert_almost_equal_2(5.501, 5.5, 0.01) + with self.assertRaises(RuntimeError): + test_assert_almost_equal_2(5.5, 5.52, 0.01) + with self.assertRaises(RuntimeError): + test_assert_almost_equal_2(5.52, 5.5, 0.01) + + test_assert_not_almost_equal(5.5, 5.501) + test_assert_not_almost_equal(5.501, 5.5) + with self.assertRaises(RuntimeError): + test_assert_not_almost_equal(5.5, 5.50000001) + with self.assertRaises(RuntimeError): + test_assert_not_almost_equal(5.50000001, 5.5) + + test_assert_not_almost_equal_2(5.5, 6.6, 1) + test_assert_not_almost_equal_2(6.6, 5.5, 1) + with self.assertRaises(RuntimeError): + test_assert_not_almost_equal_2(5.5, 5.6, 1) + with self.assertRaises(RuntimeError): + test_assert_not_almost_equal_2(5.6, 5.5, 1)