diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..3fdc7ef --- /dev/null +++ b/__init__.py @@ -0,0 +1,13 @@ +# fractional_indexing/__init__.py + +from fractional_indexing import ( + generate_key_between, + generate_n_keys_between +) +from exceptions import OrderKeyError + +__all__ = [ + 'generate_key_between', + 'generate_n_keys_between', + 'OrderKeyError' +] diff --git a/exceptions.py b/exceptions.py new file mode 100644 index 0000000..c6b5e55 --- /dev/null +++ b/exceptions.py @@ -0,0 +1,5 @@ +# fractional_indexing/exceptions.py + +class OrderKeyError(Exception): + """Custom error for invalid order keys.""" + pass diff --git a/fractional_indexing.py b/fractional_indexing.py index 7bbc722..4d6ffd3 100644 --- a/fractional_indexing.py +++ b/fractional_indexing.py @@ -1,285 +1,131 @@ -""" -Provides functions for generating ordering strings +# fractional_indexing/fractional_indexing.py -. - - - - -""" from math import floor from typing import Optional, List -import decimal - - -__version__ = '0.1.3' -__licence__ = 'CC0 1.0 Universal' - -BASE_62_DIGITS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - - -class FIError(Exception): - pass +from exceptions import OrderKeyError +from utils import BASE_62_DIGITS, validate_order_key +from utils import ( + get_integer_part, + find_middle_key, + decrement_integer, + increment_integer, +) -def midpoint(a: str, b: Optional[str], digits: str) -> str: - """ - `a` may be empty string, `b` is null or non-empty string. - `a < b` lexicographically if `b` is non-null. - no trailing zeros allowed. - digits is a string such as '0123456789' for base 10. Digits must be in - ascending character code order! 
# GENERATE KEY BETWEEN HANDLERS AND FUNCTION
def handle_end_key_only_case(end_key: str, digits: str) -> str:
    """Return an order key that sorts before `end_key` (no lower bound).

    Args:
        end_key: Upper bound of the requested key.
        digits: Digit alphabet in ascending character order; ``digits[0]``
            plays the role of zero.

    Returns:
        A key that is lexicographically smaller than `end_key`.

    Raises:
        OrderKeyError: If the integer part cannot be decremented any further.
    """
    lowest_head = 'A' + digits[0] * 26
    head = get_integer_part(end_key)

    if head == lowest_head:
        # The integer part is already the smallest one, so the only room
        # left is inside the fractional part.
        tail = end_key[len(head):]
        return head + find_middle_key('', tail, digits)

    if head < end_key:
        # `head` is a strict prefix of `end_key` (a fractional part exists),
        # so the bare integer part already sorts before it.
        return head

    previous = decrement_integer(head, digits)
    if previous is not None:
        return previous
    raise OrderKeyError('Cannot decrement anymore')


def handle_start_key_only_case(start_key: str, digits: str) -> str:
    """Return an order key that sorts after `start_key` (no upper bound).

    Args:
        start_key: Lower bound of the requested key.
        digits: Digit alphabet in ascending character order.

    Returns:
        The next integer key when one exists; otherwise `start_key`'s integer
        part followed by a fractional part that extends past the current one.
    """
    head = get_integer_part(start_key)
    successor = increment_integer(head, digits)
    if successor is not None:
        return successor

    # The integer part cannot grow any more: extend the fraction instead.
    tail = start_key[len(head):]
    return head + find_middle_key(tail, None, digits)
def handle_both_keys_case(start_key: str, end_key: str, digits: str) -> str:
    """Return an order key strictly between `start_key` and `end_key`.

    Args:
        start_key: Lower bound.
        end_key: Upper bound.
        digits: Digit alphabet in ascending character order.

    Returns:
        A key `k` with ``start_key < k < end_key``.

    Raises:
        OrderKeyError: If the integer part of `start_key` cannot be
            incremented although the two integer parts differ.
    """
    low_head = get_integer_part(start_key)
    high_head = get_integer_part(end_key)
    low_tail = start_key[len(low_head):]

    if low_head == high_head:
        # Same integer part: the answer lies between the two fractions.
        high_tail = end_key[len(high_head):]
        return low_head + find_middle_key(low_tail, high_tail, digits)

    successor = increment_integer(low_head, digits)
    if successor is None:
        raise OrderKeyError('Cannot increment anymore')

    if successor < end_key:
        return successor

    # The next integer would reach `end_key`, so stay on the current integer
    # and extend the fractional part of `start_key` instead.
    return low_head + find_middle_key(low_tail, None, digits)


def generate_key_between(start_key: Optional[str], end_key: Optional[str], digits: str = BASE_62_DIGITS) -> str:
    """Generate an order key that sorts between `start_key` and `end_key`.

    Args:
        start_key: Lower bound, or None for "no lower bound".
        end_key: Upper bound, or None for "no upper bound".
        digits: Digit alphabet in ascending character order.

    Returns:
        The new key. When both bounds are None this is ``'a'`` followed by
        the zero digit.

    Raises:
        OrderKeyError: If a bound is not a valid order key, or if
            ``start_key >= end_key``.
    """
    first_key = 'a' + digits[0]

    for bound in (start_key, end_key):
        if bound is not None:
            validate_order_key(bound, digits)

    has_start = start_key is not None
    has_end = end_key is not None

    if has_start and has_end:
        if start_key >= end_key:
            raise OrderKeyError(f'{start_key} >= {end_key}')
        return handle_both_keys_case(start_key, end_key, digits)
    if has_start:
        return handle_start_key_only_case(start_key, digits)
    if has_end:
        return handle_end_key_only_case(end_key, digits)
    return first_key


# GENERATE N
# KEYS BETWEEN HANDLERS AND FUNCTION
def handle_generate_n_keys_with_end_none(start_key: Optional[str], number_of_keys: int, digits: str) -> List[str]:
    """Generate consecutive keys after `start_key` when there is no upper bound.

    Args:
        start_key: Lower bound, or None.
        number_of_keys: How many keys to produce; values <= 0 yield ``[]``.
        digits: Digit alphabet in ascending character order.

    Returns:
        The keys in ascending order.
    """
    keys: List[str] = []
    lower_bound = start_key
    for _ in range(number_of_keys):
        # Each new key becomes the lower bound of the next one.
        lower_bound = generate_key_between(lower_bound, None, digits)
        keys.append(lower_bound)
    return keys


def handle_generate_n_keys_with_start_none(end_key: Optional[str], number_of_keys: int, digits: str) -> List[str]:
    """Generate consecutive keys before `end_key` when there is no lower bound.

    Args:
        end_key: Upper bound, or None.
        number_of_keys: How many keys to produce; values <= 0 yield ``[]``.
        digits: Digit alphabet in ascending character order.

    Returns:
        The keys in ascending order.
    """
    keys: List[str] = []
    upper_bound = end_key
    for _ in range(number_of_keys):
        # Walk downwards: each new key becomes the upper bound of the next.
        upper_bound = generate_key_between(None, upper_bound, digits)
        keys.append(upper_bound)
    # Keys were produced from highest to lowest; flip to ascending order.
    keys.reverse()
    return keys


def generate_n_keys_between(start_key: Optional[str], end_key: Optional[str], number_of_keys: int, digits: str = BASE_62_DIGITS) -> List[str]:
    """Generate `number_of_keys` distinct order keys between two bounds.

    Args:
        start_key: Lower bound, or None for "no lower bound".
        end_key: Upper bound, or None for "no upper bound".
        number_of_keys: How many keys to produce. Zero or a negative value
            yields an empty list (a negative count previously recursed
            without end when both bounds were given).
        digits: Digit alphabet in ascending character order.

    Returns:
        The keys in ascending order.

    Raises:
        OrderKeyError: Propagated from `generate_key_between` for invalid
            bounds.
    """
    if number_of_keys <= 0:
        return []

    if number_of_keys == 1:
        return [generate_key_between(start_key, end_key, digits)]

    if end_key is None:
        return handle_generate_n_keys_with_end_none(start_key, number_of_keys, digits)

    if start_key is None:
        return handle_generate_n_keys_with_start_none(end_key, number_of_keys, digits)

    # Both bounds given: pick the middle key, then fill each half
    # recursively so the generated keys stay short.
    left_count = number_of_keys // 2
    right_count = number_of_keys - left_count - 1
    middle_key = generate_key_between(start_key, end_key, digits)

    return [
        *generate_n_keys_between(start_key, middle_key, left_count, digits),
        middle_key,
        *generate_n_keys_between(middle_key, end_key, right_count, digits),
    ]
OrderKeyError('invalid order key: A00000000000000000000000000')), (None, 'A000000000000000000000000001', 'A000000000000000000000000000V'), ('zzzzzzzzzzzzzzzzzzzzzzzzzzy', None, 'zzzzzzzzzzzzzzzzzzzzzzzzzzz'), ('zzzzzzzzzzzzzzzzzzzzzzzzzzz', None, 'zzzzzzzzzzzzzzzzzzzzzzzzzzzV'), - ('a00', None, FIError('invalid order key: a00')), - ('a00', 'a1', FIError('invalid order key: a00')), - ('0', '1', FIError('invalid order key head: 0')), - ('a1', 'a0', FIError('a1 >= a0')), + ('a00', None, OrderKeyError('invalid order key: a00')), + ('a00', 'a1', OrderKeyError('invalid order key: a00')), + ('0', '1', OrderKeyError('invalid order key head: 0')), + ('a1', 'a0', OrderKeyError('a1 >= a0')), ]) -def test_generate_key_between(a: Optional[str], b: Optional[str], expected: str) -> None: - if isinstance(expected, FIError): - with pytest.raises(FIError) as e: +def test_generate_key_between(a: Optional[str], b: Optional[str], expected) -> None: + if isinstance(expected, OrderKeyError): + with pytest.raises(OrderKeyError) as e: generate_key_between(a, b) - assert e.value.args[0] == expected.args[0] + assert e.value.args[0].lower() == expected.args[0].lower() else: act = generate_key_between(a, b) print(f'exp: {expected}') @@ -50,6 +54,7 @@ def test_generate_key_between(a: Optional[str], b: Optional[str], expected: str) assert act == expected + @pytest.mark.parametrize(['a', 'b', 'n', 'expected'], [ (None, None, 5, 'a0 a1 a2 a3 a4'), ('a4', None, 10, 'a5 a6 a7 a8 a9 b00 b01 b02 b03 b04'), @@ -72,25 +77,25 @@ def test_generate_n_keys_between(a: Optional[str], b: Optional[str], n: int, exp (None, None, 'a '), ('a ', None, 'a!'), (None, 'a ', 'Z~'), - ('a0 ', 'a0!', FIError('invalid order key: a0 ')), + ('a0 ', 'a0!', OrderKeyError('invalid order key: a0 ')), (None, 'A 0', 'A ('), ('a~', None, 'b '), ('Z~', None, 'a '), - ('b ', None, FIError('invalid order key: b ')), + ('b ', None, OrderKeyError('invalid order key: b ')), ('a0', 'a0V', 'a0;'), ('a 1', 'a 2', 'a 1P'), - (None, 'A ', 
FIError('invalid order key: A ')), + (None, 'A ', OrderKeyError('invalid order key: A ')), ]) def test_base95_digits(a: Optional[str], b: Optional[str], expected: str) -> None: kwargs = { - 'a': a, - 'b': b, + 'start_key': a, + 'end_key': b, 'digits': BASE_95_DIGITS, } - if isinstance(expected, FIError): - with pytest.raises(FIError) as e: + if isinstance(expected, OrderKeyError): + with pytest.raises(OrderKeyError) as e: generate_key_between(**kwargs) - assert e.value.args[0] == expected.args[0] + assert e.value.args[0].lower() == expected.args[0].lower() else: act = generate_key_between(**kwargs) print() @@ -124,31 +129,28 @@ def test_readme_examples_single_key(): def test_readme_examples_multiple_keys(): # Insert 3 at the beginning - keys = generate_n_keys_between(None, None, n=3) + keys = generate_n_keys_between(None, None, number_of_keys=3) assert keys == ['a0', 'a1', 'a2'] # Insert 3 after 1st - keys = generate_n_keys_between('a0', None, n=3) + keys = generate_n_keys_between('a0', None, number_of_keys=3) assert keys == ['a1', 'a2', 'a3'] # Insert 3 before 1st - keys = generate_n_keys_between(None, 'a0', n=3) + keys = generate_n_keys_between(None, 'a0', number_of_keys=3) assert keys == ['Zx', 'Zy', 'Zz'] # Insert 3 in between 2nd and 3rd. 
BASE_62_DIGITS: str = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'


def round_half_up(value: float) -> int:
    """Round `value` to the nearest integer, with halves rounded up.

    Goes through `decimal` because the built-in `round` rounds halves to the
    nearest even integer.
    """
    return int(
        decimal.Decimal(str(value)).quantize(
            decimal.Decimal('1'),
            rounding=decimal.ROUND_HALF_UP
        )
    )


def validate_integer(order_key: str) -> None:
    """Check that `order_key` is exactly as long as its head character demands.

    Raises:
        OrderKeyError: If the length does not match, or if the head character
            is missing (empty string) or not an ASCII letter.
    """
    # Slice instead of index so an empty string reaches the head check and
    # raises OrderKeyError rather than a bare IndexError.
    if len(order_key) != get_integer_length(order_key[:1]):
        raise OrderKeyError(f'Invalid integer part of order key: {order_key}')


def get_integer_length(first_char: str) -> int:
    """Return the length of the integer part announced by `first_char`.

    'a'..'z' map to 2..27 and 'Z'..'A' map to 2..27; the returned length
    includes the head character itself.

    Raises:
        OrderKeyError: If `first_char` is not in 'a'..'z' or 'A'..'Z'.
    """
    if 'a' <= first_char <= 'z':
        return ord(first_char) - ord('a') + 2
    elif 'A' <= first_char <= 'Z':
        return ord('Z') - ord(first_char) + 2
    raise OrderKeyError('Invalid order key head: ' + first_char)


def get_integer_part(order_key: str) -> str:
    """Return the integer part (head character plus its digits) of `order_key`.

    Raises:
        OrderKeyError: If the key is empty, has an invalid head character, or
            is shorter than the integer part its head announces.
    """
    # `[:1]` keeps an empty key on the OrderKeyError path (no IndexError).
    integer_part_length = get_integer_length(order_key[:1])
    if integer_part_length > len(order_key):
        raise OrderKeyError(f'Invalid order key: {order_key}')
    return order_key[:integer_part_length]


def validate_order_key(order_key: str, digits: str = BASE_62_DIGITS) -> None:
    """Check the validity of an order key.

    Args:
        order_key: Key to validate.
        digits: Digit alphabet in ascending character order.

    Raises:
        OrderKeyError: If the key is empty, is the reserved smallest key
            ('A' followed by 26 zero digits), has a malformed integer part,
            or has a fractional part ending in the zero digit.
    """
    zero = digits[0]
    smallest_valid_key = 'A' + (zero * 26)

    if order_key == smallest_valid_key:
        raise OrderKeyError(f'Invalid order key: {order_key}')

    integer_part = get_integer_part(order_key)
    fractional_part = order_key[len(integer_part):]

    if fractional_part and fractional_part[-1] == zero:
        raise OrderKeyError(f'Invalid order key: {order_key}')


def find_middle_key(start_key: str, end_key: Optional[str], digits: str) -> str:
    """Return a digit string that sorts between `start_key` and `end_key`.

    Args:
        start_key: Lower bound; may be the empty string.
        end_key: Upper bound, or None for "no upper bound".
        digits: Digit alphabet in ascending character order.

    Returns:
        A string `m` with ``start_key < m`` and, when `end_key` is given,
        ``m < end_key``.

    Raises:
        OrderKeyError: If ``start_key >= end_key`` or either bound ends with
            the zero digit.
    """
    zero = digits[0]

    if end_key is not None and start_key >= end_key:
        raise OrderKeyError(f'{start_key} >= {end_key}')

    if (start_key and start_key[-1] == zero) or (end_key and end_key[-1] == zero):
        raise OrderKeyError('Trailing zero in order key')

    if end_key:
        # Strip the longest common prefix, padding `start_key` with zeros so
        # it can be compared position by position with `end_key`.
        padded_start = start_key.ljust(len(end_key), zero)
        common_prefix_len = 0
        for char_start, char_end in zip(padded_start, end_key):
            if char_start != char_end:
                break
            common_prefix_len += 1

        if common_prefix_len > 0:
            return end_key[:common_prefix_len] + find_middle_key(
                start_key[common_prefix_len:], end_key[common_prefix_len:], digits
            )

    # First digits (or lack of a digit) differ from here on.
    digit_a = digits.index(start_key[0]) if start_key else 0
    digit_b = digits.index(end_key[0]) if end_key else len(digits)

    if digit_b - digit_a > 1:
        # At least one digit fits strictly between the two: take the middle.
        return digits[round_half_up(0.5 * (digit_a + digit_b))]

    if end_key and len(end_key) > 1:
        # Adjacent digits, but `end_key` continues: its first digit alone
        # already sorts between the two bounds.
        return end_key[:1]

    # Adjacent digits and nothing left of `end_key`: keep the lower digit and
    # find room further right.
    return digits[digit_a] + find_middle_key(start_key[1:], None, digits)


def increment_integer(integer_str: str, digits: str) -> Optional[str]:
    """Return the integer part that follows `integer_str`.

    Args:
        integer_str: A well-formed integer part (head plus digits).
        digits: Digit alphabet in ascending character order.

    Returns:
        The next integer part, or None when `integer_str` is already the
        largest one (head 'z' with every digit at its maximum).

    Raises:
        OrderKeyError: If `integer_str` is not a well-formed integer part.
    """
    zero = digits[0]
    validate_integer(integer_str)

    head, *digits_list = integer_str
    has_carry_over = True

    # Add one starting from the least significant digit.
    for i in reversed(range(len(digits_list))):
        current_digit = digits.index(digits_list[i]) + 1
        if current_digit == len(digits):
            digits_list[i] = zero
        else:
            digits_list[i] = digits[current_digit]
            has_carry_over = False
            break

    if has_carry_over:
        if head == 'Z':
            return 'a' + zero
        if head == 'z':
            return None
        next_head = chr(ord(head) + 1)
        # Lowercase heads gain a digit as they grow; uppercase heads lose one.
        if next_head > 'a':
            digits_list.append(zero)
        else:
            digits_list.pop()
        return next_head + ''.join(digits_list)

    return head + ''.join(digits_list)


def decrement_integer(integer_str: str, digits: str) -> Optional[str]:
    """Return the integer part that precedes `integer_str`.

    Args:
        integer_str: A well-formed integer part (head plus digits).
        digits: Digit alphabet in ascending character order.

    Returns:
        The previous integer part, or None when `integer_str` is already the
        smallest one (head 'A' with every digit at zero).

    Raises:
        OrderKeyError: If `integer_str` is not a well-formed integer part.
    """
    validate_integer(integer_str)

    head, *digits_list = integer_str
    requires_borrow = True

    # Subtract one starting from the least significant digit.
    for i in reversed(range(len(digits_list))):
        current_digit = digits.index(digits_list[i]) - 1

        if current_digit == -1:
            digits_list[i] = digits[-1]
        else:
            digits_list[i] = digits[current_digit]
            requires_borrow = False
            break

    if requires_borrow:
        if head == 'a':
            return 'Z' + digits[-1]
        if head == 'A':
            return None
        next_head = chr(ord(head) - 1)
        # Uppercase heads gain a digit as they shrink; lowercase heads lose one.
        if next_head < 'Z':
            digits_list.append(digits[-1])
        else:
            digits_list.pop()
        return next_head + ''.join(digits_list)

    return head + ''.join(digits_list)