Skip to content

Commit 531031d

Browse files
authored
Add support for complex, Decimal, and Fraction suffixes (#148)
* Add support for complex, Decimal, and Fraction suffixes * Fix type hints * Ignore linter * Add extra tests and convert impossible cases to asserts
1 parent 581933d commit 531031d

File tree

6 files changed

+215
-30
lines changed

6 files changed

+215
-30
lines changed

basilisp/compiler.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
import uuid
99
from collections import OrderedDict
1010
from datetime import datetime
11+
from decimal import Decimal
1112
from enum import Enum
13+
from fractions import Fraction
1214
from itertools import chain
1315
from typing import (Dict, Iterable, Pattern, Tuple, Optional, List, Union, Callable, Mapping, NamedTuple, cast, Deque,
1416
Any)
@@ -393,6 +395,8 @@ def _expressionize(body: MixedNodeStream,
393395
_NS_VAR_VALUE = f'{_NS_VAR}.value'
394396

395397
_NS_VAR_NAME = _load_attr(f'{_NS_VAR_VALUE}.name')
398+
_NEW_DECIMAL_FN_NAME = _load_attr(f'{_UTIL_ALIAS}.decimal_from_str')
399+
_NEW_FRACTION_FN_NAME = _load_attr(f'{_UTIL_ALIAS}.fraction')
396400
_NEW_INST_FN_NAME = _load_attr(f'{_UTIL_ALIAS}.inst_from_str')
397401
_NEW_KW_FN_NAME = _load_attr(f'{_KW_ALIAS}.keyword')
398402
_NEW_LIST_FN_NAME = _load_attr(f'{_LIST_ALIAS}.list')
@@ -1570,16 +1574,29 @@ def _sym_ast(ctx: CompilerContext, form: sym.Symbol) -> ASTStream:
15701574
ctx=ast.Load()))
15711575

15721576

1573-
def _regex_ast(_: CompilerContext, form: Pattern) -> ASTStream:
1577+
def _decimal_ast(_: CompilerContext, form: Decimal) -> ASTStream:
15741578
yield _node(ast.Call(
1575-
func=_NEW_REGEX_FN_NAME, args=[ast.Str(form.pattern)], keywords=[]))
1579+
func=_NEW_DECIMAL_FN_NAME, args=[ast.Str(str(form))], keywords=[]))
1580+
1581+
1582+
def _fraction_ast(_: CompilerContext, form: Fraction) -> ASTStream:
1583+
yield _node(ast.Call(
1584+
func=_NEW_FRACTION_FN_NAME,
1585+
args=[ast.Num(form.numerator),
1586+
ast.Num(form.denominator)],
1587+
keywords=[]))
15761588

15771589

15781590
def _inst_ast(_: CompilerContext, form: datetime) -> ASTStream:
15791591
yield _node(ast.Call(
15801592
func=_NEW_INST_FN_NAME, args=[ast.Str(form.isoformat())], keywords=[]))
15811593

15821594

1595+
def _regex_ast(_: CompilerContext, form: Pattern) -> ASTStream:
1596+
yield _node(ast.Call(
1597+
func=_NEW_REGEX_FN_NAME, args=[ast.Str(form.pattern)], keywords=[]))
1598+
1599+
15831600
def _uuid_ast(_: CompilerContext, form: uuid.UUID) -> ASTStream:
15841601
yield _node(ast.Call(
15851602
func=_NEW_UUID_FN_NAME, args=[ast.Str(str(form))], keywords=[]))
@@ -1675,12 +1692,18 @@ def _to_ast(ctx: CompilerContext, form: LispForm) -> ASTStream: # pylint: disab
16751692
elif isinstance(form, float):
16761693
yield _node(ast.Num(form))
16771694
return
1678-
elif isinstance(form, int):
1695+
elif isinstance(form, (complex, int)):
16791696
yield _node(ast.Num(form))
16801697
return
16811698
elif isinstance(form, datetime):
16821699
yield from _inst_ast(ctx, form)
16831700
return
1701+
elif isinstance(form, Decimal):
1702+
yield from _decimal_ast(ctx, form)
1703+
return
1704+
elif isinstance(form, Fraction):
1705+
yield from _fraction_ast(ctx, form)
1706+
return
16841707
elif isinstance(form, uuid.UUID):
16851708
yield from _uuid_ast(ctx, form)
16861709
return

basilisp/lang/typing.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22
from datetime import datetime
3+
from decimal import Decimal
34
from fractions import Fraction
45
from typing import Union, Pattern
56

@@ -11,7 +12,7 @@
1112
import basilisp.lang.vector as vec
1213

1314
LispNumber = Union[int, float, Fraction]
14-
LispForm = Union[bool, datetime, int, float, kw.Keyword, llist.List,
15-
lmap.Map, None, Pattern, lset.Set, str, sym.Symbol,
16-
vec.Vector, uuid.UUID]
15+
LispForm = Union[bool, complex, datetime, Decimal, int, float, Fraction,
16+
kw.Keyword, llist.List, lmap.Map, None, Pattern, lset.Set,
17+
str, sym.Symbol, vec.Vector, uuid.UUID]
1718
IterableLispForm = Union[llist.List, lmap.Map, lset.Set, vec.Vector]

basilisp/lang/util.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import keyword
44
import re
55
import uuid
6+
from decimal import Decimal
67
from fractions import Fraction
78
from typing import Pattern
89

@@ -21,16 +22,20 @@ def lrepr(f) -> str:
2122
return "nil"
2223
elif isinstance(f, str):
2324
return f'"{f}"'
25+
elif isinstance(f, complex):
26+
return repr(f).upper()
2427
elif isinstance(f, datetime.datetime):
2528
inst_str = f.isoformat()
2629
return f'#inst "{inst_str}"'
30+
elif isinstance(f, Decimal):
31+
return str(f)
32+
elif isinstance(f, Fraction):
33+
return f"{f.numerator}/{f.denominator}"
34+
elif isinstance(f, Pattern):
35+
return f'#"{f.pattern}"'
2736
elif isinstance(f, uuid.UUID):
2837
uuid_str = str(f)
2938
return f'#uuid "{uuid_str}"'
30-
elif isinstance(f, Pattern):
31-
return f'#"{f.pattern}"'
32-
elif isinstance(f, Fraction):
33-
return f"{f.numerator}/{f.denominator}"
3439
else:
3540
return repr(f)
3641

@@ -85,6 +90,16 @@ def genname(prefix: str) -> str:
8590
return f"{prefix}_{i}"
8691

8792

93+
def decimal_from_str(decimal_str: str) -> Decimal:
94+
"""Create a Decimal from a numeric string."""
95+
return Decimal(decimal_str)
96+
97+
98+
def fraction(numerator: int, denominator: int) -> Fraction:
99+
"""Create a Fraction from a numerator and denominator."""
100+
return Fraction(numerator=numerator, denominator=denominator)
101+
102+
88103
def inst_from_str(inst_str: str) -> datetime.datetime:
89104
"""Create a datetime instance from an RFC 3339 formatted date string."""
90105
return dateparser.parse(inst_str)

basilisp/reader.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import collections
22
import contextlib
3+
import decimal
34
import functools
45
import io
56
import re
67
import uuid
78
from datetime import datetime
9+
from fractions import Fraction
810
from typing import (Deque, List, Tuple, Optional, Collection, Callable, Any, Union, MutableMapping, Pattern, Iterable,
911
TypeVar, cast, Dict)
1012

@@ -408,14 +410,19 @@ def _read_map(ctx: ReaderContext) -> lmap.Map:
408410
# special keywords `true`, `false`, and `nil`, we have to have a looser
409411
# type defined for the return from these reader functions.
410412
MaybeSymbol = Union[bool, None, symbol.Symbol]
411-
MaybeNumber = Union[float, int, MaybeSymbol]
413+
MaybeNumber = Union[complex, decimal.Decimal, float, Fraction, int, MaybeSymbol]
412414

413415

414-
def _read_num(ctx: ReaderContext) -> MaybeNumber:
415-
"""Return a numeric (integer or float) from the input stream."""
416+
def _read_num(ctx: ReaderContext) -> MaybeNumber: # noqa: C901 # pylint: disable=too-many-statements
417+
"""Return a numeric (complex, Decimal, float, int, Fraction) from the input stream."""
416418
chars: List[str] = []
417419
reader = ctx.reader
420+
421+
is_complex = False
422+
is_decimal = False
418423
is_float = False
424+
is_integer = False
425+
is_ratio = False
419426
while True:
420427
token = reader.peek()
421428
if token == '-':
@@ -435,16 +442,55 @@ def _read_num(ctx: ReaderContext) -> MaybeNumber:
435442
raise SyntaxError(
436443
"Found extra '.' in float; expected decimal portion")
437444
is_float = True
445+
elif token == 'J':
446+
if is_complex:
447+
raise SyntaxError("Found extra 'J' suffix in complex literal")
448+
is_complex = True
449+
elif token == 'M':
450+
if is_decimal:
451+
raise SyntaxError("Found extra 'M' suffix in decimal literal")
452+
is_decimal = True
453+
elif token == 'N':
454+
if is_integer:
455+
raise SyntaxError("Found extra 'N' suffix in integer literal")
456+
is_integer = True
457+
elif token == '/':
458+
if is_ratio:
459+
raise SyntaxError("Found extra '/' in ratio literal")
460+
is_ratio = True
438461
elif not num_chars.match(token):
439462
break
440463
reader.next_token()
441464
chars.append(token)
442465

443-
if len(chars) == 0:
444-
raise SyntaxError("Expected integer or float")
466+
assert len(chars) > 0, "Must have at least one digit in integer or float"
445467

446468
s = ''.join(chars)
447-
return float(s) if is_float else int(s)
469+
if sum([is_complex and is_decimal,
470+
is_complex and is_integer,
471+
is_complex and is_ratio,
472+
is_decimal or is_float,
473+
is_integer,
474+
is_ratio]) > 1:
475+
raise SyntaxError(f"Invalid number format: {s}")
476+
477+
if is_complex:
478+
imaginary = float(s[:-1]) if is_float else int(s[:-1])
479+
return complex(0, imaginary)
480+
elif is_decimal:
481+
try:
482+
return decimal.Decimal(s[:-1])
483+
except decimal.InvalidOperation:
484+
raise SyntaxError(f"Invalid number format: {s}") from None
485+
elif is_float:
486+
return float(s)
487+
elif is_ratio:
488+
assert "/" in s, "Ratio must contain one '/' character"
489+
num, denominator = s.split('/')
490+
return Fraction(numerator=int(num), denominator=int(denominator))
491+
elif is_integer:
492+
return int(s[:-1])
493+
return int(s)
448494

449495

450496
def _read_str(ctx: ReaderContext) -> str:

tests/compiler_test.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import decimal
12
import re
23
import types
34
import uuid
5+
from fractions import Fraction
46
from typing import Optional
57
from unittest.mock import Mock
68

@@ -77,12 +79,30 @@ def test_string():
7779

7880

7981
def test_int():
80-
assert lcompile('1') == 1
81-
assert lcompile('100') == 100
82-
assert lcompile('99927273') == 99927273
83-
assert lcompile('0') == 0
84-
assert lcompile('-1') == -1
85-
assert lcompile('-538282') == -538282
82+
assert 1 == lcompile('1')
83+
assert 100 == lcompile('100')
84+
assert 99927273 == lcompile('99927273')
85+
assert 0 == lcompile('0')
86+
assert -1 == lcompile('-1')
87+
assert -538282 == lcompile('-538282')
88+
89+
assert 1 == lcompile('1N')
90+
assert 100 == lcompile('100N')
91+
assert 99927273 == lcompile('99927273N')
92+
assert 0 == lcompile('0N')
93+
assert -1 == lcompile('-1N')
94+
assert -538282 == lcompile('-538282N')
95+
96+
97+
def test_decimal():
98+
assert decimal.Decimal('0.0') == lcompile('0.0M')
99+
assert decimal.Decimal('0.09387372') == lcompile('0.09387372M')
100+
assert decimal.Decimal('1.0') == lcompile('1.0M')
101+
assert decimal.Decimal('1.332') == lcompile('1.332M')
102+
assert decimal.Decimal('-1.332') == lcompile('-1.332M')
103+
assert decimal.Decimal('-1.0') == lcompile('-1.0M')
104+
assert decimal.Decimal('-0.332') == lcompile('-0.332M')
105+
assert decimal.Decimal('3.14') == lcompile('3.14M')
86106

87107

88108
def test_float():
@@ -659,15 +679,19 @@ def test_var(ns_var: Var):
659679
assert v.value == "a value"
660680

661681

682+
def test_fraction(ns_var: Var):
683+
assert Fraction('22/7') == lcompile('22/7')
684+
685+
662686
def test_inst(ns_var: Var):
663-
assert lcompile('#inst "2018-01-18T03:26:57.296-00:00"'
664-
) == dateparser.parse('2018-01-18T03:26:57.296-00:00')
687+
assert dateparser.parse('2018-01-18T03:26:57.296-00:00') == lcompile(
688+
'#inst "2018-01-18T03:26:57.296-00:00"')
665689

666690

667691
def test_regex(ns_var: Var):
668692
assert lcompile('#"\s"') == re.compile('\s')
669693

670694

671695
def test_uuid(ns_var: Var):
672-
assert lcompile('#uuid "0366f074-a8c5-4764-b340-6a5576afd2e8"'
673-
) == uuid.UUID('{0366f074-a8c5-4764-b340-6a5576afd2e8}')
696+
assert uuid.UUID('{0366f074-a8c5-4764-b340-6a5576afd2e8}') == lcompile(
697+
'#uuid "0366f074-a8c5-4764-b340-6a5576afd2e8"')

0 commit comments

Comments
 (0)