12
12
import pint_xarray
13
13
import pint
14
14
15
- ureg = pint .UnitRegistry ()
16
- Q_ = ureg .Quantity
17
15
except ImportError :
18
16
pint_xarray = None
19
17
@@ -24,75 +22,84 @@ def add_attrs(rv, attrs, name):
24
22
rv .attrs = attrs
25
23
26
24
27
- def quantify (rv , attrs ):
25
+ def quantify (rv , attrs , unit_registry = None ):
26
+ if unit_registry is None :
27
+ return rv
28
+
28
29
if isinstance (rv , xr .DataArray ):
29
- rv = rv .pint .quantify ()
30
+ rv = rv .pint .quantify (unit_registry = unit_registry )
30
31
else :
31
32
if attrs is not None :
32
33
# Necessary to use the Q_ and not simply multiplication with ureg unit because of temperature
33
34
# see https://pint.readthedocs.io/en/latest/nonmult.html
34
- rv = Q_ (rv , attrs ["units" ])
35
+ rv = unit_registry . Quantity (rv , attrs ["units" ])
35
36
return rv
36
37
37
38
38
39
def pint_compat (args , kwargs ):
39
40
if pint_xarray is None :
40
- return args , kwargs , False
41
+ return args , kwargs , None
41
42
42
43
using_pint = False
43
44
new_args = []
44
45
new_kwargs = {}
46
+ registries = []
45
47
for arg in args :
46
48
if isinstance (arg , xr .DataArray ):
47
49
if arg .pint .units is not None :
48
50
new_args .append (arg .pint .dequantify ())
49
- using_pint = True
51
+ registries . append ( arg . pint . registry )
50
52
else :
51
53
new_args .append (arg )
52
54
elif isinstance (arg , pint .Quantity ):
53
55
new_args .append (arg .magnitude )
54
- using_pint = True
56
+ registries . append ( arg . _REGISTRY )
55
57
else :
56
58
new_args .append (arg )
57
59
58
60
for kw , arg in kwargs .items ():
59
61
if isinstance (arg , xr .DataArray ):
60
62
if arg .pint .units is not None :
61
63
new_kwargs [kw ] = arg .pint .dequantify ()
62
- using_pint = True
64
+ registries . append ( arg . pint . registry )
63
65
else :
64
66
new_kwargs [kw ] = arg
65
67
elif isinstance (arg , pint .Quantity ):
66
68
new_kwargs [kw ] = arg .magnitude
67
- using_pint = True
69
+ registries . append ( arg . _REGISTRY )
68
70
else :
69
71
new_kwargs [kw ] = arg
70
72
71
- return new_args , new_kwargs , using_pint
73
+ registries = set (registries )
74
+ if len (registries ) > 1 :
75
+ raise ValueError ("Quantity arguments must all belong to the same unit registry" )
76
+ elif len (registries ) == 0 :
77
+ registries = None
78
+ else :
79
+ (registries ,) = registries
80
+ return new_args , new_kwargs , registries
72
81
73
82
74
83
def cf_attrs (attrs , name , check_func ):
75
84
def cf_attrs_decorator (func ):
76
85
@wraps (func )
77
86
def cf_attrs_wrapper (* args , ** kwargs ):
78
- args , kwargs , is_quantity = pint_compat (args , kwargs )
87
+ args , kwargs , unit_registry = pint_compat (args , kwargs )
79
88
rv = func (* args , ** kwargs )
80
89
attrs_checked = check_func (attrs , args , kwargs )
81
90
if isinstance (rv , tuple ):
82
91
rv_updated = []
83
92
for (i , da ) in enumerate (rv ):
84
93
add_attrs (da , attrs_checked [i ], name [i ])
85
- if is_quantity :
86
- rv_updated .append (quantify (da , attrs_checked [i ]))
87
- else :
88
- rv_updated .append (da )
94
+ rv_updated .append (
95
+ quantify (da , attrs_checked [i ], unit_registry = unit_registry )
96
+ )
89
97
90
98
rv = tuple (rv_updated )
91
99
92
100
else :
93
101
add_attrs (rv , attrs_checked , name )
94
- if is_quantity :
95
- rv = quantify (rv , attrs_checked )
102
+ rv = quantify (rv , attrs_checked , unit_registry = unit_registry )
96
103
return rv
97
104
98
105
return cf_attrs_wrapper
0 commit comments