|
| 1 | +# standard library |
| 2 | +from typing import Any |
| 3 | + |
| 4 | + |
| 5 | +# dependencies |
| 6 | +from astropy.units import Quantity |
| 7 | +from pytest import mark, raises |
| 8 | +from xarray import DataArray |
| 9 | +from xarray.testing import assert_identical # type: ignore |
| 10 | +from xarray_units import operator as opr |
| 11 | +from xarray_units.operator import Operator, take |
| 12 | +from xarray_units.quantity import set |
| 13 | +from xarray_units.utils import UnitsApplicationError |
| 14 | + |
| 15 | + |
| 16 | +# test data |
| 17 | +km = set(DataArray([1, 2, 3]), "km") |
| 18 | +mm = set(DataArray([1, 2, 3]) * 1e6, "mm") |
| 19 | +sc_1 = 2 |
| 20 | +sc_2 = Quantity(2000, "m") |
| 21 | + |
| 22 | + |
| 23 | +data_take: list[tuple[DataArray, Operator, Any, Any]] = [ |
| 24 | + (km, "mul", sc_1, set(DataArray([2, 4, 6]), "km")), |
| 25 | + (km, "mul", sc_2, set(DataArray([2, 4, 6]) * 1e3, "km m")), |
| 26 | + (km, "mul", mm, set(DataArray([1, 4, 9]) * 1e6, "km mm")), |
| 27 | + (mm, "mul", km, set(DataArray([1, 4, 9]) * 1e6, "km mm")), |
| 28 | + # |
| 29 | + (km, "pow", sc_1, set(DataArray([1, 4, 9]), "km2")), |
| 30 | + (km, "pow", sc_2, UnitsApplicationError), |
| 31 | + (km, "pow", mm, UnitsApplicationError), |
| 32 | + (mm, "pow", km, UnitsApplicationError), |
| 33 | + # |
| 34 | + (km, "matmul", sc_1, UnitsApplicationError), |
| 35 | + (km, "matmul", sc_2, UnitsApplicationError), |
| 36 | + (km, "matmul", mm, set(DataArray(14) * 1e6, "km mm")), |
| 37 | + (mm, "matmul", km, set(DataArray(14) * 1e6, "km mm")), |
| 38 | + # |
| 39 | + (km, "truediv", sc_1, set(DataArray([0.5, 1, 1.5]), "km")), |
| 40 | + (km, "truediv", sc_2, set(DataArray([0.5, 1.0, 1.5]) * 1e-3, "km m-1")), |
| 41 | + (km, "truediv", mm, set(DataArray([1, 1, 1]) * 1e-6, "km mm-1")), |
| 42 | + (mm, "truediv", km, set(DataArray([1, 1, 1]) * 1e6, "mm km-1")), |
| 43 | + # |
| 44 | + (km, "add", sc_1, UnitsApplicationError), |
| 45 | + (km, "add", sc_2, set(DataArray([3, 4, 5]), "km")), |
| 46 | + (km, "add", mm, set(DataArray([2, 4, 6]), "km")), |
| 47 | + (mm, "add", km, set(DataArray([2, 4, 6]) * 1e6, "mm")), |
| 48 | + # |
| 49 | + (km, "sub", sc_1, UnitsApplicationError), |
| 50 | + (km, "sub", sc_2, set(DataArray([-1, 0, 1]), "km")), |
| 51 | + (km, "sub", mm, set(DataArray([0, 0, 0]), "km")), |
| 52 | + (mm, "sub", km, set(DataArray([0, 0, 0]), "mm")), |
| 53 | + # |
| 54 | + (km, "floordiv", sc_1, UnitsApplicationError), |
| 55 | + (km, "floordiv", sc_2, set(DataArray([0, 1, 1]), "1")), |
| 56 | + (km, "floordiv", mm, set(DataArray([1, 1, 1]), "1")), |
| 57 | + (mm, "floordiv", km, set(DataArray([1, 1, 1]), "1")), |
| 58 | + # |
| 59 | + (km, "mod", sc_1, UnitsApplicationError), |
| 60 | + (km, "mod", sc_2, set(DataArray([1, 0, 1]), "km")), |
| 61 | + (km, "mod", mm, set(DataArray([0, 0, 0]), "km")), |
| 62 | + (mm, "mod", km, set(DataArray([0, 0, 0]), "mm")), |
| 63 | + # |
| 64 | + (km, "lt", sc_1, UnitsApplicationError), |
| 65 | + (km, "lt", sc_2, DataArray([True, False, False])), |
| 66 | + (km, "lt", mm, DataArray([False, False, False])), |
| 67 | + (mm, "lt", km, DataArray([False, False, False])), |
| 68 | + # |
| 69 | + (km, "le", sc_1, UnitsApplicationError), |
| 70 | + (km, "le", sc_2, DataArray([True, True, False])), |
| 71 | + (km, "le", mm, DataArray([True, True, True])), |
| 72 | + (mm, "le", km, DataArray([True, True, True])), |
| 73 | + # |
| 74 | + (km, "eq", sc_1, UnitsApplicationError), |
| 75 | + (km, "eq", sc_2, DataArray([False, True, False])), |
| 76 | + (km, "eq", mm, DataArray([True, True, True])), |
| 77 | + (mm, "eq", km, DataArray([True, True, True])), |
| 78 | + # |
| 79 | + (km, "ne", sc_1, UnitsApplicationError), |
| 80 | + (km, "ne", sc_2, DataArray([True, False, True])), |
| 81 | + (km, "ne", mm, DataArray([False, False, False])), |
| 82 | + (mm, "ne", km, DataArray([False, False, False])), |
| 83 | + # |
| 84 | + (km, "ge", sc_1, UnitsApplicationError), |
| 85 | + (km, "ge", sc_2, DataArray([False, True, True])), |
| 86 | + (km, "ge", mm, DataArray([True, True, True])), |
| 87 | + (mm, "ge", km, DataArray([True, True, True])), |
| 88 | + # |
| 89 | + (km, "gt", sc_1, UnitsApplicationError), |
| 90 | + (km, "gt", sc_2, DataArray([False, False, True])), |
| 91 | + (km, "gt", mm, DataArray([False, False, False])), |
| 92 | + (mm, "gt", km, DataArray([False, False, False])), |
| 93 | +] |
| 94 | + |
| 95 | + |
| 96 | +@mark.parametrize("left, operator, right, expected", data_take) |
| 97 | +def test_take( |
| 98 | + left: DataArray, |
| 99 | + operator: Operator, |
| 100 | + right: Any, |
| 101 | + expected: Any, |
| 102 | +) -> None: |
| 103 | + if expected is UnitsApplicationError: |
| 104 | + with raises(expected): |
| 105 | + take(left, operator, right) |
| 106 | + else: |
| 107 | + assert_identical(take(left, operator, right), expected) |
| 108 | + |
| 109 | + |
| 110 | +@mark.parametrize("left, operator, right, expected", data_take) |
| 111 | +def test_take_alias( |
| 112 | + left: DataArray, |
| 113 | + operator: Operator, |
| 114 | + right: Any, |
| 115 | + expected: DataArray, |
| 116 | +) -> None: |
| 117 | + if expected is UnitsApplicationError: |
| 118 | + with raises(expected): |
| 119 | + getattr(opr, operator)(left, right) |
| 120 | + else: |
| 121 | + assert_identical(getattr(opr, operator)(left, right), expected) |
0 commit comments