|
6 | 6 | import uuid |
7 | 7 | from datetime import datetime |
8 | 8 | from datetime import timezone |
| 9 | +from threading import Lock |
9 | 10 | from typing import Any |
10 | 11 | from typing import cast |
11 | 12 | from typing import Generic |
@@ -53,8 +54,42 @@ def wrapped(cls: Any, value: T) -> R: |
53 | 54 | return wrapped |
54 | 55 |
|
55 | 56 |
|
| 57 | +class ValueProvider: |
| 58 | + def __init__(self) -> None: |
| 59 | + self.lock = Lock() |
| 60 | + self.prev_timestamp = constants.MIN_TIMESTAMP |
| 61 | + self.prev_randomness = constants.MIN_RANDOMNESS |
| 62 | + |
| 63 | + def timestamp(self, value: float | None = None) -> int: |
| 64 | + if value is None: |
| 65 | + value = time.time_ns() // constants.NANOSECS_IN_MILLISECS |
| 66 | + elif isinstance(value, float): |
| 67 | + value = int(value * constants.MILLISECS_IN_SECS) |
| 68 | + if value > constants.MAX_TIMESTAMP: |
| 69 | + raise ValueError("Value exceeds maximum possible timestamp") |
| 70 | + return value |
| 71 | + |
| 72 | + def randomness(self) -> bytes: |
| 73 | + with self.lock: |
| 74 | + current_timestamp = self.timestamp() |
| 75 | + if current_timestamp == self.prev_timestamp: |
| 76 | + if self.prev_randomness == constants.MAX_RANDOMNESS: |
| 77 | + raise ValueError("Randomness within same millisecond exhausted") |
| 78 | + randomness = (int.from_bytes(self.prev_randomness) + 1).to_bytes( |
| 79 | + constants.RANDOMNESS_LEN, byteorder="big" |
| 80 | + ) |
| 81 | + else: |
| 82 | + randomness = os.urandom(constants.RANDOMNESS_LEN) |
| 83 | + |
| 84 | + self.prev_randomness = randomness |
| 85 | + self.prev_timestamp = current_timestamp |
| 86 | + return randomness |
| 87 | + |
| 88 | + |
56 | 89 | @functools.total_ordering |
57 | 90 | class ULID: |
| 91 | + provider = ValueProvider() |
| 92 | + |
58 | 93 | """The :class:`ULID` object consists of a timestamp part of 48 bits and of 80 random bits. |
59 | 94 |
|
60 | 95 | .. code-block:: text |
@@ -82,9 +117,7 @@ class ULID: |
82 | 117 | def __init__(self, value: bytes | None = None) -> None: |
83 | 118 | if value is not None and len(value) != constants.BYTES_LEN: |
84 | 119 | raise ValueError("ULID has to be exactly 16 bytes long.") |
85 | | - self.bytes: bytes = ( |
86 | | - value or ULID.from_timestamp(time.time_ns() // constants.NANOSECS_IN_MILLISECS).bytes |
87 | | - ) |
| 120 | + self.bytes: bytes = value or ULID.from_timestamp(self.provider.timestamp()).bytes |
88 | 121 |
|
89 | 122 | @classmethod |
90 | 123 | @validate_type(datetime) |
@@ -113,10 +146,8 @@ def from_timestamp(cls, value: float) -> Self: |
113 | 146 | >>> ULID.from_timestamp(time.time()) |
114 | 147 | ULID(01E75QWN5HKQ0JAVX9FG1K4YP4) |
115 | 148 | """ |
116 | | - if isinstance(value, float): |
117 | | - value = int(value * constants.MILLISECS_IN_SECS) |
118 | | - timestamp = int.to_bytes(value, constants.TIMESTAMP_LEN, "big") |
119 | | - randomness = os.urandom(constants.RANDOMNESS_LEN) |
| 149 | + timestamp = int.to_bytes(cls.provider.timestamp(value), constants.TIMESTAMP_LEN, "big") |
| 150 | + randomness = cls.provider.randomness() |
120 | 151 | return cls.from_bytes(timestamp + randomness) |
121 | 152 |
|
122 | 153 | @classmethod |
|
0 commit comments