11from __future__ import annotations
22
3- from collections .abc import Callable
3+ from collections .abc import Callable , Hashable
44from functools import wraps
5- from typing import Generic , Optional , TypeVar
5+ from typing import Generic , TypeVar , Any , cast , overload , TYPE_CHECKING
6+ from typing_extensions import ParamSpec
67
7- T = TypeVar ("T" )
8+ if TYPE_CHECKING :
9+ from typing_extensions import TypeAlias
10+
11+ T = TypeVar ("T" , bound = Hashable )
812U = TypeVar ("U" )
13+ P = ParamSpec ("P" )
14+ R = TypeVar ("R" )
15+
16+ if TYPE_CHECKING :
17+ NodeKey : TypeAlias = T | None
18+ NodeValue : TypeAlias = U | None
19+ else :
20+ NodeKey = TypeVar ("NodeKey" , bound = Hashable )
21+ NodeValue = TypeVar ("NodeValue" )
922
1023
1124class DoubleLinkedListNode (Generic [T , U ]):
@@ -16,11 +29,11 @@ class DoubleLinkedListNode(Generic[T, U]):
1629 Node: key: 1, val: 1, has next: False, has prev: False
1730 """
1831
19- def __init__ (self , key : Optional [ T ] , val : Optional [ U ] ) -> None :
32+ def __init__ (self , key : NodeKey , val : NodeValue ) -> None :
2033 self .key = key
2134 self .val = val
22- self .next : Optional [ DoubleLinkedListNode [T , U ]] = None
23- self .prev : Optional [ DoubleLinkedListNode [T , U ]] = None
35+ self .next : DoubleLinkedListNode [T , U ] | None = None
36+ self .prev : DoubleLinkedListNode [T , U ] | None = None
2437
2538 def __repr__ (self ) -> str :
2639 return (
@@ -48,21 +61,20 @@ def __repr__(self) -> str:
4861 node = node .next
4962 rep .append (str (self .rear ))
5063 return ",\n " .join (rep )
51-
5264 def add (self , node : DoubleLinkedListNode [T , U ]) -> None :
5365 """Adds the given node to the end of the list (before rear)"""
5466 previous = self .rear .prev
5567 if previous is None :
5668 raise ValueError ("Invalid list state: rear.prev is None" )
57-
69+
5870 previous .next = node
5971 node .prev = previous
6072 self .rear .prev = node
6173 node .next = self .rear
6274
6375 def remove (
6476 self , node : DoubleLinkedListNode [T , U ]
65- ) -> Optional [ DoubleLinkedListNode [T , U ]] :
77+ ) -> DoubleLinkedListNode [T , U ] | None :
6678 """Removes and returns the given node from the list"""
6779 if node .prev is None or node .next is None :
6880 return None
@@ -97,7 +109,7 @@ def __repr__(self) -> str:
97109 def __contains__ (self , key : T ) -> bool :
98110 return key in self .cache
99111
100- def get (self , key : T ) -> Optional [ U ] :
112+ def get (self , key : T ) -> U | None :
101113 """Returns the value for the input key"""
102114 if key in self .cache :
103115 self .hits += 1
@@ -119,7 +131,6 @@ def put(self, key: T, value: U) -> None:
119131 node .val = value
120132 self .list .add (node )
121133 return
122-
123134 if self .num_keys >= self .capacity :
124135 first_node = self .list .head .next
125136 if first_node is None or first_node .key is None :
@@ -133,34 +144,51 @@ def put(self, key: T, value: U) -> None:
133144 self .list .add (new_node )
134145 self .num_keys += 1
135146
147+ @overload
136148 @classmethod
137149 def decorator (
138150 cls , size : int = 128
139- ) -> Callable [[Callable [..., U ]], Callable [..., U ]]:
140- """Decorator version of LRU Cache"""
151+ ) -> Callable [[Callable [P , R ]], Callable [P , R ]]:
152+ ...
153+
154+ @overload
155+ @classmethod
156+ def decorator (
157+ cls , func : Callable [P , R ]
158+ ) -> Callable [P , R ]:
159+ ...
141160
142- def decorator_func (func : Callable [..., U ]) -> Callable [..., U ]:
143- cache_instance = cls (size )
161+ @classmethod
162+ def decorator (
163+ cls , size : int | Callable [P , R ] = 128
164+ ) -> Callable [[Callable [P , R ]], Callable [P , R ]] | Callable [P , R ]:
165+ """Decorator version of LRU Cache"""
166+ if callable (size ):
167+ # Called without parentheses (@LRUCache.decorator)
168+ return cls .decorator ()(size )
169+
170+ def decorator_func (func : Callable [P , R ]) -> Callable [P , R ]:
171+ cache_instance = cls [Any , R ](size ) # type: ignore[valid-type]
144172
145173 @wraps (func )
146- def wrapper (* args : T , ** kwargs : T ) -> U :
147- key = (args , tuple (kwargs .items ()))
174+ def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> R :
175+ # Create normalized key
176+ sorted_kwargs = tuple (sorted (kwargs .items (), key = lambda x : x [0 ]))
177+ key = (args , sorted_kwargs )
148178 result = cache_instance .get (key )
149179 if result is None :
150180 result = func (* args , ** kwargs )
151181 cache_instance .put (key , result )
152182 return result
153-
154- def cache_info () -> LRUCache :
183+ def cache_info () -> LRUCache [Any , R ]: # type: ignore[valid-type]
155184 return cache_instance
156185
157- setattr ( wrapper , " cache_info" , cache_info )
186+ wrapper . cache_info = cache_info # Direct assignment
158187 return wrapper
159188
160189 return decorator_func
161190
162191
163192if __name__ == "__main__" :
164193 import doctest
165-
166194 doctest .testmod ()
0 commit comments