Skip to content

Commit ce04668

Browse files
authored
Add custom UnitRegistry to handle scaling factors more cleanly (#212)
* Add custom UnitRegistry to handle scaling factors more cleanly * Maintain support for base pint Unit objects
1 parent 6df6538 commit ce04668

File tree

5 files changed

+103
-45
lines changed

5 files changed

+103
-45
lines changed

gemd/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.0.0"
1+
__version__ = "2.1.0"

gemd/units/impl.py

Lines changed: 84 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Implementation of units."""
2+
from deprecation import deprecated
23
import functools
34
from importlib.resources import read_text
45
import os
@@ -7,7 +8,7 @@
78
from tempfile import TemporaryDirectory
89
from typing import Union, List, Tuple, Generator, Any
910

10-
from pint import UnitRegistry, Unit, register_unit_format
11+
from pint import UnitRegistry, register_unit_format
1112
try: # Pint 0.23 migrated the location of this method, and augmented it
1213
from pint.pint_eval import tokenizer
1314
except ImportError: # pragma: no cover
@@ -131,6 +132,7 @@ def _scaling_identify_factors(
131132
"""
132133
todo = []
133134
for block in blocks:
135+
# Note: while Python does not recognize ^ as exponentiation, pint does
134136
i_exp = next((i for i, t in enumerate(block) if t.string in {"**", "^"}), len(block))
135137
i_name = next((i for i, t in enumerate(block) if t.type == NAME), None)
136138
numbers = [(i, t.string) for i, t in enumerate(block) if t.type == NUMBER and i < i_exp]
@@ -168,10 +170,14 @@ def _scaling_store_and_mangle(input_string: str, todo: List[Tuple[str, str, str]
168170
"""
169171
for scaled_term, number_string, unit_string in todo:
170172
regex = rf"(?<![-+0-9.]){re.escape(scaled_term)}(?![0-9.])"
171-
stripped = re.sub(r"[+\s]+", "", scaled_term).replace("--", "")
173+
stripped = re.sub(
174+
r"(?<=\d)_(?=\d)", "", re.sub(r"[+\s]+", "", scaled_term).replace("--", "")
175+
)
172176

173177
if unit_string is not None:
174-
stripped_unit = re.sub(r"[+\s]+", "", unit_string).replace("--", "")
178+
stripped_unit = re.sub(
179+
r"(?<!0)(?=\.)", "0", re.sub(r"[+\s]+", "", unit_string)
180+
).replace("--", "")
175181
long_unit = f"{_REGISTRY.parse_units(stripped_unit)}"
176182
short_unit = f"{_REGISTRY.parse_units(stripped_unit):~}"
177183
long = stripped.replace(stripped_unit, "_" + long_unit)
@@ -201,7 +207,58 @@ def _scaling_preprocessor(input_string: str) -> str:
201207
return _scaling_store_and_mangle(input_string, todo)
202208

203209

204-
_REGISTRY: UnitRegistry = None # global requires it be defined in this scope
210+
def _unmangle_scaling(input_string: str) -> str:
211+
"""Convert mangled scaling values into a pint-compatible expression."""
212+
number_re = r'\b_(_)?(\d+)(_\d+)?([eE]_?\d+)?(_(?=[a-zA-Z]))?'
213+
while match := re.search(number_re, input_string):
214+
replacement = '' if match.group(1) is None else '-'
215+
replacement += match.group(2)
216+
replacement += '' if match.group(3) is None else match.group(3).replace('_', '.')
217+
replacement += '' if match.group(4) is None else match.group(4).replace('_', '-')
218+
replacement += '' if match.group(5) is None else match.group(5).replace('_', ' ')
219+
input_string = input_string.replace(match.group(0), replacement)
220+
return input_string
221+
222+
223+
try: # pragma: no cover
224+
# Pint 0.23 modified the preferred way to derive a custom class
225+
# https://pint.readthedocs.io/en/0.23/advanced/custom-registry-class.html
226+
from pint.registry import GenericUnitRegistry
227+
from typing_extensions import TypeAlias
228+
229+
class _ScaleFactorUnit(UnitRegistry.Unit):
230+
"""Child class of Units for generating units w/ clean scaling factors."""
231+
232+
def __format__(self, format_spec):
233+
result = super().__format__(format_spec)
234+
return _unmangle_scaling(result)
235+
236+
class _ScaleFactorQuantity(UnitRegistry.Quantity):
237+
"""Child class of Quantity for generating units w/ clean scaling factors."""
238+
239+
pass
240+
241+
class _ScaleFactorRegistry(GenericUnitRegistry[_ScaleFactorQuantity, _ScaleFactorUnit]):
242+
"""UnitRegistry class that uses _GemdUnits."""
243+
244+
Quantity: TypeAlias = _ScaleFactorQuantity
245+
Unit: TypeAlias = _ScaleFactorUnit
246+
247+
except ImportError: # pragma: no cover
248+
# https://pint.readthedocs.io/en/0.21/advanced/custom-registry-class.html
249+
class _ScaleFactorUnit(UnitRegistry.Unit):
250+
"""Child class of Units for generating units w/ clean scaling factors."""
251+
252+
def __format__(self, format_spec):
253+
result = super().__format__(format_spec)
254+
return _unmangle_scaling(result)
255+
256+
class _ScaleFactorRegistry(UnitRegistry):
257+
"""UnitRegistry class that uses _GemdUnits."""
258+
259+
_unit_class = _ScaleFactorUnit
260+
261+
_REGISTRY: _ScaleFactorRegistry = None # global requires it be defined in this scope
205262

206263

207264
@functools.lru_cache(maxsize=1024 * 1024)
@@ -244,38 +301,23 @@ def convert_units(value: float, starting_unit: str, final_unit: str) -> float:
244301

245302

246303
@register_unit_format("clean")
304+
@deprecated(deprecated_in="2.1.0", removed_in="3.0.0", details="Scaling factor clean-up ")
247305
def _format_clean(unit, registry, **options):
248-
"""Formatter that turns scaling-factor-units into numbers again."""
249-
numerator = []
250-
denominator = []
251-
for u, p in unit.items():
252-
if re.match(r"_[\d_]+", u):
253-
# Munged scaling factor; grab symbol, which is the prettier
254-
u = registry.get_symbol(u)
255-
256-
if p == 1:
257-
numerator.append(u)
258-
elif p > 0:
259-
numerator.append(f"{u} ** {p}")
260-
elif p == -1:
261-
denominator.append(u)
262-
elif p < 0:
263-
denominator.append(f"{u} ** {-p}")
264-
265-
if len(numerator) == 0:
266-
numerator = ["1"]
267-
268-
if len(denominator) > 0:
269-
return " / ".join((" * ".join(numerator), " / ".join(denominator)))
270-
else:
271-
return " * ".join(numerator)
306+
"""
307+
DEPRECATED Formatter that turns scaling-factor-units into numbers again.
308+
309+
Responsibility for this piece of clean-up has been shifted to a custom class.
310+
311+
"""
312+
from pint.formatting import _FORMATTERS
313+
return _FORMATTERS["D"](unit, registry, **options)
272314

273315

274316
@functools.lru_cache(maxsize=1024)
275-
def parse_units(units: Union[str, Unit, None],
317+
def parse_units(units: Union[str, UnitRegistry.Unit, None],
276318
*,
277319
return_unit: bool = False
278-
) -> Union[str, Unit, None]:
320+
) -> Union[str, UnitRegistry.Unit, None]:
279321
"""
280322
Parse a string or Unit into a standard string representation of the unit.
281323
@@ -298,19 +340,20 @@ def parse_units(units: Union[str, Unit, None],
298340
else:
299341
return None
300342
elif isinstance(units, str):
301-
parsed = _REGISTRY.parse_units(units)
343+
# SPT-1311 Protect against leaked mangled strings
344+
parsed = _REGISTRY.parse_units(_unmangle_scaling(units))
302345
if return_unit:
303346
return parsed
304347
else:
305-
return f"{parsed:clean}"
306-
elif isinstance(units, Unit):
348+
return f"{parsed}"
349+
elif isinstance(units, UnitRegistry.Unit):
307350
return units
308351
else:
309352
raise UndefinedUnitError("Units must be given as a recognized unit string or Units object")
310353

311354

312355
@functools.lru_cache(maxsize=1024)
313-
def get_base_units(units: Union[str, Unit]) -> Tuple[Unit, float, float]:
356+
def get_base_units(units: Union[str, UnitRegistry.Unit]) -> Tuple[UnitRegistry.Unit, float, float]:
314357
"""
315358
Get the base units and conversion factors for the given unit.
316359
@@ -358,13 +401,13 @@ def change_definitions_file(filename: str = None):
358401
path = Path(target)
359402
os.chdir(path.parent)
360403
# Need to re-verify path because of some slippiness around tmp on MacOS
361-
_REGISTRY = UnitRegistry(filename=Path.cwd() / path.name,
362-
preprocessors=[_space_after_minus_preprocessor,
363-
_scientific_notation_preprocessor,
364-
_scaling_preprocessor
365-
],
366-
autoconvert_offset_to_baseunit=True
367-
)
404+
_REGISTRY = _ScaleFactorRegistry(filename=Path.cwd() / path.name,
405+
preprocessors=[_space_after_minus_preprocessor,
406+
_scientific_notation_preprocessor,
407+
_scaling_preprocessor
408+
],
409+
autoconvert_offset_to_baseunit=True
410+
)
368411
finally:
369412
os.chdir(current_dir)
370413

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pint==0.20
22
deprecation==2.1.0
3+
typing-extensions==4.8.0

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
},
3838
install_requires=[
3939
"pint>=0.20,<0.24",
40-
"deprecation>=2.1.0,<3"
40+
"deprecation>=2.1.0,<3",
41+
"typing_extensions>=4.8,<5"
4142
],
4243
extras_require={
4344
"tests": [

tests/units/test_parser.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from contextlib import contextmanager
2+
from deprecation import DeprecatedWarning
23
from importlib.resources import read_binary
34
import re
45
from pint import UnitRegistry
@@ -30,7 +31,10 @@ def test_parse_expected(return_unit):
3031
"g / -+-25e-1 m", # Weird but fine
3132
"ug / - -250 mL", # Spaces between unaries is acceptable
3233
"1 / 10**5 degC", # Spaces between unaries is acceptable
33-
"m ** - 1" # Pint < 0.21 throws DefinitionSyntaxError
34+
"1 / 10_000 degC", # Spaces between unaries is acceptable
35+
"m ** - 1", # Pint < 0.21 throws DefinitionSyntaxError
36+
"gram / _10_minute", # Stringified Unit object SPT-1311
37+
"gram / __1_2e_3minute", # Stringified Unit object SPT-1311
3438
]
3539
for unit in expected:
3640
parsed = parse_units(unit, return_unit=return_unit)
@@ -205,8 +209,17 @@ def test_exponents():
205209

206210
def test__scientific_notation_preprocessor():
207211
"""Verify that numbers are converted into scientific notation."""
208-
assert "1e2 kg" in parse_units("F* 10 ** 2 kg")
212+
assert "1e2 kilogram" in parse_units("F* 10 ** 2 kg")
213+
assert "1e2 kg" in f'{parse_units("F* 10 ** 2 kg", return_unit=True):~}'
209214
assert "1e-5" in parse_units("F* mm*10**-5")
210215
assert "1e" not in parse_units("F* kg * 10 cm")
211216
assert "-3.07e2" in parse_units("F* -3.07 * 10 ** 2")
212217
assert "11e2" in parse_units("F* 11*10^2")
218+
219+
220+
def test_deprecation():
221+
"""Make sure deprecated things warn correctly."""
222+
megapascals = parse_units("MPa", return_unit=True)
223+
with pytest.warns(DeprecatedWarning):
224+
stringified = f"{megapascals:clean}"
225+
assert megapascals == parse_units(stringified, return_unit=True)

0 commit comments

Comments
 (0)