@@ -50,30 +50,65 @@ def test_xarray_quantity(ds_pint):
50
50
sigma0 = gsw_xarray .sigma0 (SA = ds_pint .SA , CT = ds_pint .CT )
51
51
assert sigma0 .pint .units == pint_xarray .unit_registry ("kg / m^3" )
52
52
53
- @pytest .mark .parametrize ("SA_type" , ['unit' , 'ds' ])
54
- @pytest .mark .parametrize ("CT_type" , ['unit' , 'ds' ])
53
+
54
+ @pytest .mark .parametrize ("SA_type" , ["unit" , "ds" ])
55
+ @pytest .mark .parametrize ("CT_type" , ["unit" , "ds" ])
55
56
def test_xarray_quantity_or_ds (ds , ds_pint , SA_type , CT_type ):
56
57
"""If at least 1 of the inputs is quantity, the result should be quantity"""
57
58
pint_xarray = pytest .importorskip ("pint_xarray" )
58
- if SA_type == ' unit' :
59
+ if SA_type == " unit" :
59
60
SA = ds_pint .SA
60
- elif SA_type == 'ds' :
61
+ elif SA_type == "ds" :
61
62
SA = ds .SA
62
-
63
- if CT_type == ' unit' :
63
+
64
+ if CT_type == " unit" :
64
65
CT = ds_pint .CT
65
- elif CT_type == 'ds' :
66
+ elif CT_type == "ds" :
66
67
CT = ds .CT
67
-
68
+
68
69
sigma0 = gsw_xarray .sigma0 (SA = SA , CT = CT )
69
- if SA_type == ' unit' or CT_type == ' unit' :
70
+ if SA_type == " unit" or CT_type == " unit" :
70
71
assert sigma0 .pint .units == pint_xarray .unit_registry ("kg / m^3" )
71
72
else :
72
73
assert sigma0 .pint .units is None
73
- assert sigma0 .pint .quantify ().pint .units == pint_xarray .unit_registry ("kg / m^3" )
74
+ assert sigma0 .pint .quantify ().pint .units == pint_xarray .unit_registry (
75
+ "kg / m^3"
76
+ )
74
77
75
78
76
79
def test_func_return_tuple_quantity (ds_pint ):
77
80
pint_xarray = pytest .importorskip ("pint_xarray" )
78
81
(CT_SA , CT_pt ) = gsw_xarray .CT_first_derivatives (ds_pint .SA , 1 )
79
82
assert CT_SA .pint .units == pint_xarray .unit_registry ("K/(g/kg)" )
83
+
84
+
85
+ def test_pint_quantity_xarray (ds ):
86
+ """If input is mixed between xr.DataArray and pint quantity it should return pint-xarray wrapped quantity"""
87
+ pint_xarray = pytest .importorskip ("pint_xarray" )
88
+ import pint
89
+
90
+ ureg = pint .UnitRegistry ()
91
+ Q_ = ureg .Quantity
92
+ sigma0 = gsw_xarray .sigma0 (SA = ds .SA , CT = Q_ (25.4 , ureg .degC ))
93
+ assert sigma0 .pint .units == pint_xarray .unit_registry ("kg / m^3" )
94
+
95
+
96
+ def test_pint_quantity ():
97
+ """If input is pint quantity should return a quantity"""
98
+ pint_xarray = pytest .importorskip ("pint_xarray" )
99
+ import pint
100
+
101
+ ureg = pint .UnitRegistry ()
102
+ CT = gsw_xarray .CT_from_pt (SA = 35 * ureg ("g / kg" ), pt = 10 )
103
+ assert isinstance (CT , pint .Quantity )
104
+
105
+
106
+ def test_pint_quantity_tuple ():
107
+ """If input is pint quantity should return a quantity"""
108
+ pint_xarray = pytest .importorskip ("pint_xarray" )
109
+ import pint
110
+
111
+ ureg = pint .UnitRegistry ()
112
+ (a , b ) = gsw_xarray .CT_first_derivatives (35 * ureg ("g / kg" ), pt = 1 )
113
+ assert isinstance (a , pint .Quantity )
114
+ assert isinstance (b , pint .Quantity )
0 commit comments