Skip to content

Commit e92bb19

Browse files
authored
Update lru_cache.py
1 parent 824f2bc commit e92bb19

File tree

1 file changed

+49
-21
lines changed

1 file changed

+49
-21
lines changed

other/lru_cache.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
3+
from collections.abc import Callable, Hashable
44
from functools import wraps
5-
from typing import Generic, Optional, TypeVar
5+
from typing import Generic, TypeVar, Any, cast, overload, TYPE_CHECKING
6+
from typing_extensions import ParamSpec
67

7-
T = TypeVar("T")
8+
if TYPE_CHECKING:
9+
from typing_extensions import TypeAlias
10+
11+
T = TypeVar("T", bound=Hashable)
812
U = TypeVar("U")
13+
P = ParamSpec("P")
14+
R = TypeVar("R")
15+
16+
if TYPE_CHECKING:
17+
NodeKey: TypeAlias = T | None
18+
NodeValue: TypeAlias = U | None
19+
else:
20+
NodeKey = TypeVar("NodeKey", bound=Hashable)
21+
NodeValue = TypeVar("NodeValue")
922

1023

1124
class DoubleLinkedListNode(Generic[T, U]):
@@ -16,11 +29,11 @@ class DoubleLinkedListNode(Generic[T, U]):
1629
Node: key: 1, val: 1, has next: False, has prev: False
1730
"""
1831

19-
def __init__(self, key: Optional[T], val: Optional[U]) -> None:
32+
def __init__(self, key: NodeKey, val: NodeValue) -> None:
2033
self.key = key
2134
self.val = val
22-
self.next: Optional[DoubleLinkedListNode[T, U]] = None
23-
self.prev: Optional[DoubleLinkedListNode[T, U]] = None
35+
self.next: DoubleLinkedListNode[T, U] | None = None
36+
self.prev: DoubleLinkedListNode[T, U] | None = None
2437

2538
def __repr__(self) -> str:
2639
return (
@@ -48,21 +61,20 @@ def __repr__(self) -> str:
4861
node = node.next
4962
rep.append(str(self.rear))
5063
return ",\n ".join(rep)
51-
5264
def add(self, node: DoubleLinkedListNode[T, U]) -> None:
5365
"""Adds the given node to the end of the list (before rear)"""
5466
previous = self.rear.prev
5567
if previous is None:
5668
raise ValueError("Invalid list state: rear.prev is None")
57-
69+
5870
previous.next = node
5971
node.prev = previous
6072
self.rear.prev = node
6173
node.next = self.rear
6274

6375
def remove(
6476
self, node: DoubleLinkedListNode[T, U]
65-
) -> Optional[DoubleLinkedListNode[T, U]]:
77+
) -> DoubleLinkedListNode[T, U] | None:
6678
"""Removes and returns the given node from the list"""
6779
if node.prev is None or node.next is None:
6880
return None
@@ -97,7 +109,7 @@ def __repr__(self) -> str:
97109
def __contains__(self, key: T) -> bool:
98110
return key in self.cache
99111

100-
def get(self, key: T) -> Optional[U]:
112+
def get(self, key: T) -> U | None:
101113
"""Returns the value for the input key"""
102114
if key in self.cache:
103115
self.hits += 1
@@ -119,7 +131,6 @@ def put(self, key: T, value: U) -> None:
119131
node.val = value
120132
self.list.add(node)
121133
return
122-
123134
if self.num_keys >= self.capacity:
124135
first_node = self.list.head.next
125136
if first_node is None or first_node.key is None:
@@ -133,34 +144,51 @@ def put(self, key: T, value: U) -> None:
133144
self.list.add(new_node)
134145
self.num_keys += 1
135146

147+
@overload
136148
@classmethod
137149
def decorator(
138150
cls, size: int = 128
139-
) -> Callable[[Callable[..., U]], Callable[..., U]]:
140-
"""Decorator version of LRU Cache"""
151+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
152+
...
153+
154+
@overload
155+
@classmethod
156+
def decorator(
157+
cls, func: Callable[P, R]
158+
) -> Callable[P, R]:
159+
...
141160

142-
def decorator_func(func: Callable[..., U]) -> Callable[..., U]:
143-
cache_instance = cls(size)
161+
@classmethod
162+
def decorator(
163+
cls, size: int | Callable[P, R] = 128
164+
) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
165+
"""Decorator version of LRU Cache"""
166+
if callable(size):
167+
# Called without parentheses (@LRUCache.decorator)
168+
return cls.decorator()(size)
169+
170+
def decorator_func(func: Callable[P, R]) -> Callable[P, R]:
171+
cache_instance = cls[Any, R](size) # type: ignore[valid-type]
144172

145173
@wraps(func)
146-
def wrapper(*args: T, **kwargs: T) -> U:
147-
key = (args, tuple(kwargs.items()))
174+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
175+
# Create normalized key
176+
sorted_kwargs = tuple(sorted(kwargs.items(), key=lambda x: x[0]))
177+
key = (args, sorted_kwargs)
148178
result = cache_instance.get(key)
149179
if result is None:
150180
result = func(*args, **kwargs)
151181
cache_instance.put(key, result)
152182
return result
153-
154-
def cache_info() -> LRUCache:
183+
def cache_info() -> LRUCache[Any, R]: # type: ignore[valid-type]
155184
return cache_instance
156185

157-
setattr(wrapper, "cache_info", cache_info)
186+
wrapper.cache_info = cache_info # Direct assignment
158187
return wrapper
159188

160189
return decorator_func
161190

162191

163192
if __name__ == "__main__":
164193
import doctest
165-
166194
doctest.testmod()

0 commit comments

Comments
 (0)