Skip to content

Commit 14e7c00

Browse files
authored
Merge pull request #966 from DenverCoderOne/number-type-improvements
Numbers and core type fixes
2 parents 25e4360 + 053c242 commit 14e7c00

File tree

2 files changed

+36
-44
lines changed

2 files changed

+36
-44
lines changed

babel/core.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import os
1414
import pickle
1515
from collections.abc import Iterable, Mapping
16-
from typing import TYPE_CHECKING, Any, overload
16+
from typing import TYPE_CHECKING, Any
1717

1818
from babel import localedata
1919
from babel.plural import PluralRule
@@ -260,21 +260,13 @@ def negotiate(
260260
if identifier:
261261
return Locale.parse(identifier, sep=sep)
262262

263-
@overload
264-
@classmethod
265-
def parse(cls, identifier: None, sep: str = ..., resolve_likely_subtags: bool = ...) -> None: ...
266-
267-
@overload
268-
@classmethod
269-
def parse(cls, identifier: str | Locale, sep: str = ..., resolve_likely_subtags: bool = ...) -> Locale: ...
270-
271263
@classmethod
272264
def parse(
273265
cls,
274266
identifier: str | Locale | None,
275267
sep: str = '_',
276268
resolve_likely_subtags: bool = True,
277-
) -> Locale | None:
269+
) -> Locale:
278270
"""Create a `Locale` instance for the given locale identifier.
279271
280272
>>> l = Locale.parse('de-DE', sep='-')
@@ -317,10 +309,9 @@ def parse(
317309
identifier
318310
:raise `UnknownLocaleError`: if no locale data is available for the
319311
requested locale
312+
:raise `TypeError`: if the identifier is not a string or a `Locale`
320313
"""
321-
if identifier is None:
322-
return None
323-
elif isinstance(identifier, Locale):
314+
if isinstance(identifier, Locale):
324315
return identifier
325316
elif not isinstance(identifier, str):
326317
raise TypeError(f"Unexpected value for identifier: {identifier!r}")
@@ -364,9 +355,9 @@ def _try_load_reducing(parts):
364355
language, territory, script, variant = parts
365356
modifier = None
366357
language = get_global('language_aliases').get(language, language)
367-
territory = get_global('territory_aliases').get(territory, (territory,))[0]
368-
script = get_global('script_aliases').get(script, script)
369-
variant = get_global('variant_aliases').get(variant, variant)
358+
territory = get_global('territory_aliases').get(territory or '', (territory,))[0]
359+
script = get_global('script_aliases').get(script or '', script)
360+
variant = get_global('variant_aliases').get(variant or '', variant)
370361

371362
if territory == 'ZZ':
372363
territory = None
@@ -389,9 +380,9 @@ def _try_load_reducing(parts):
389380
if likely_subtag is not None:
390381
parts2 = parse_locale(likely_subtag)
391382
if len(parts2) == 5:
392-
language2, _, script2, variant2, modifier2 = parse_locale(likely_subtag)
383+
language2, _, script2, variant2, modifier2 = parts2
393384
else:
394-
language2, _, script2, variant2 = parse_locale(likely_subtag)
385+
language2, _, script2, variant2 = parts2
395386
modifier2 = None
396387
locale = _try_load_reducing((language2, territory, script2, variant2, modifier2))
397388
if locale is not None:
@@ -512,7 +503,7 @@ def get_territory_name(self, locale: Locale | str | None = None) -> str | None:
512503
if locale is None:
513504
locale = self
514505
locale = Locale.parse(locale)
515-
return locale.territories.get(self.territory)
506+
return locale.territories.get(self.territory or '')
516507

517508
territory_name = property(get_territory_name, doc="""\
518509
The localized territory name of the locale if available.
@@ -526,7 +517,7 @@ def get_script_name(self, locale: Locale | str | None = None) -> str | None:
526517
if locale is None:
527518
locale = self
528519
locale = Locale.parse(locale)
529-
return locale.scripts.get(self.script)
520+
return locale.scripts.get(self.script or '')
530521

531522
script_name = property(get_script_name, doc="""\
532523
The localized script name of the locale if available.
@@ -1147,7 +1138,7 @@ def negotiate_locale(preferred: Iterable[str], available: Iterable[str], sep: st
11471138
def parse_locale(
11481139
identifier: str,
11491140
sep: str = '_'
1150-
) -> tuple[str, str | None, str | None, str | None, str | None]:
1141+
) -> tuple[str, str | None, str | None, str | None] | tuple[str, str | None, str | None, str | None, str | None]:
11511142
"""Parse a locale identifier into a tuple of the form ``(language,
11521143
territory, script, variant, modifier)``.
11531144
@@ -1261,7 +1252,7 @@ def get_locale_identifier(
12611252
:param tup: the tuple as returned by :func:`parse_locale`.
12621253
:param sep: the separator for the identifier.
12631254
"""
1264-
tup = tuple(tup[:5])
1255+
tup = tuple(tup[:5]) # type: ignore # length should be no more than 5
12651256
lang, territory, script, variant, modifier = tup + (None,) * (5 - len(tup))
12661257
ret = sep.join(filter(None, (lang, script, territory, variant)))
12671258
return f'{ret}@{modifier}' if modifier else ret

babel/numbers.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import decimal
2424
import re
2525
import warnings
26-
from typing import TYPE_CHECKING, Any, overload
26+
from typing import TYPE_CHECKING, Any, cast, overload
2727

2828
from babel.core import Locale, default_locale, get_global
2929
from babel.localedata import LocaleDataDict
@@ -428,7 +428,7 @@ def get_decimal_quantum(precision: int | decimal.Decimal) -> decimal.Decimal:
428428

429429
def format_decimal(
430430
number: float | decimal.Decimal | str,
431-
format: str | None = None,
431+
format: str | NumberPattern | None = None,
432432
locale: Locale | str | None = LC_NUMERIC,
433433
decimal_quantization: bool = True,
434434
group_separator: bool = True,
@@ -474,8 +474,8 @@ def format_decimal(
474474
number format.
475475
"""
476476
locale = Locale.parse(locale)
477-
if not format:
478-
format = locale.decimal_formats.get(format)
477+
if format is None:
478+
format = locale.decimal_formats[format]
479479
pattern = parse_pattern(format)
480480
return pattern.apply(
481481
number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)
@@ -513,15 +513,15 @@ def format_compact_decimal(
513513
number, format = _get_compact_format(number, compact_format, locale, fraction_digits)
514514
# Did not find a format, fall back.
515515
if format is None:
516-
format = locale.decimal_formats.get(None)
516+
format = locale.decimal_formats[None]
517517
pattern = parse_pattern(format)
518518
return pattern.apply(number, locale, decimal_quantization=False)
519519

520520

521521
def _get_compact_format(
522522
number: float | decimal.Decimal | str,
523523
compact_format: LocaleDataDict,
524-
locale: Locale | str | None,
524+
locale: Locale,
525525
fraction_digits: int,
526526
) -> tuple[decimal.Decimal, NumberPattern | None]:
527527
"""Returns the number after dividing by the unit and the format pattern to use.
@@ -543,7 +543,7 @@ def _get_compact_format(
543543
break
544544
# otherwise, we need to divide the number by the magnitude but remove zeros
545545
# equal to the number of 0's in the pattern minus 1
546-
number = number / (magnitude // (10 ** (pattern.count("0") - 1)))
546+
number = cast(decimal.Decimal, number / (magnitude // (10 ** (pattern.count("0") - 1))))
547547
# round to the number of fraction digits requested
548548
rounded = round(number, fraction_digits)
549549
# if the remaining number is singular, use the singular format
@@ -565,7 +565,7 @@ class UnknownCurrencyFormatError(KeyError):
565565
def format_currency(
566566
number: float | decimal.Decimal | str,
567567
currency: str,
568-
format: str | None = None,
568+
format: str | NumberPattern | None = None,
569569
locale: Locale | str | None = LC_NUMERIC,
570570
currency_digits: bool = True,
571571
format_type: Literal["name", "standard", "accounting"] = "standard",
@@ -680,7 +680,7 @@ def format_currency(
680680
def _format_currency_long_name(
681681
number: float | decimal.Decimal | str,
682682
currency: str,
683-
format: str | None = None,
683+
format: str | NumberPattern | None = None,
684684
locale: Locale | str | None = LC_NUMERIC,
685685
currency_digits: bool = True,
686686
format_type: Literal["name", "standard", "accounting"] = "standard",
@@ -706,7 +706,7 @@ def _format_currency_long_name(
706706

707707
# Step 5.
708708
if not format:
709-
format = locale.decimal_formats.get(format)
709+
format = locale.decimal_formats[format]
710710

711711
pattern = parse_pattern(format)
712712

@@ -758,13 +758,15 @@ def format_compact_currency(
758758
# compress adjacent spaces into one
759759
format = re.sub(r'(\s)\s+', r'\1', format).strip()
760760
break
761+
if format is None:
762+
raise ValueError('No compact currency format found for the given number and locale.')
761763
pattern = parse_pattern(format)
762764
return pattern.apply(number, locale, currency=currency, currency_digits=False, decimal_quantization=False)
763765

764766

765767
def format_percent(
766768
number: float | decimal.Decimal | str,
767-
format: str | None = None,
769+
format: str | NumberPattern | None = None,
768770
locale: Locale | str | None = LC_NUMERIC,
769771
decimal_quantization: bool = True,
770772
group_separator: bool = True,
@@ -808,15 +810,15 @@ def format_percent(
808810
"""
809811
locale = Locale.parse(locale)
810812
if not format:
811-
format = locale.percent_formats.get(format)
813+
format = locale.percent_formats[format]
812814
pattern = parse_pattern(format)
813815
return pattern.apply(
814816
number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)
815817

816818

817819
def format_scientific(
818820
number: float | decimal.Decimal | str,
819-
format: str | None = None,
821+
format: str | NumberPattern | None = None,
820822
locale: Locale | str | None = LC_NUMERIC,
821823
decimal_quantization: bool = True,
822824
) -> str:
@@ -847,7 +849,7 @@ def format_scientific(
847849
"""
848850
locale = Locale.parse(locale)
849851
if not format:
850-
format = locale.scientific_formats.get(format)
852+
format = locale.scientific_formats[format]
851853
pattern = parse_pattern(format)
852854
return pattern.apply(
853855
number, locale, decimal_quantization=decimal_quantization)
@@ -856,7 +858,7 @@ def format_scientific(
856858
class NumberFormatError(ValueError):
857859
"""Exception raised when a string cannot be parsed into a number."""
858860

859-
def __init__(self, message: str, suggestions: str | None = None) -> None:
861+
def __init__(self, message: str, suggestions: list[str] | None = None) -> None:
860862
super().__init__(message)
861863
#: a list of properly formatted numbers derived from the invalid input
862864
self.suggestions = suggestions
@@ -1140,7 +1142,7 @@ def scientific_notation_elements(self, value: decimal.Decimal, locale: Locale |
11401142

11411143
def apply(
11421144
self,
1143-
value: float | decimal.Decimal,
1145+
value: float | decimal.Decimal | str,
11441146
locale: Locale | str | None,
11451147
currency: str | None = None,
11461148
currency_digits: bool = True,
@@ -1211,9 +1213,9 @@ def apply(
12111213
number = ''.join([
12121214
self._quantize_value(value, locale, frac_prec, group_separator),
12131215
get_exponential_symbol(locale),
1214-
exp_sign,
1215-
self._format_int(
1216-
str(exp), self.exp_prec[0], self.exp_prec[1], locale)])
1216+
exp_sign, # type: ignore # exp_sign is always defined here
1217+
self._format_int(str(exp), self.exp_prec[0], self.exp_prec[1], locale) # type: ignore # exp is always defined here
1218+
])
12171219

12181220
# Is it a significant digits pattern?
12191221
elif '@' in self.pattern:
@@ -1234,9 +1236,8 @@ def apply(
12341236
number if self.number_pattern != '' else '',
12351237
self.suffix[is_negative]])
12361238

1237-
if '¤' in retval:
1238-
retval = retval.replace('¤¤¤',
1239-
get_currency_name(currency, value, locale))
1239+
if '¤' in retval and currency is not None:
1240+
retval = retval.replace('¤¤¤', get_currency_name(currency, value, locale))
12401241
retval = retval.replace('¤¤', currency.upper())
12411242
retval = retval.replace('¤', get_currency_symbol(currency, locale))
12421243

0 commit comments

Comments
 (0)