Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 858363d

Browse files
authored
Generics for ObservableDeferred (#10491)
Now that `Deferred` is a generic class, let's update `ObeservableDeferred` to follow suit.
1 parent d0b294a commit 858363d

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

changelog.d/10491.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve type annotations for `ObservableDeferred`.

synapse/notifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ def __init__(
111111
self.last_notified_token = current_token
112112
self.last_notified_ms = time_now_ms
113113

114-
with PreserveLoggingContext():
115-
self.notify_deferred = ObservableDeferred(defer.Deferred())
114+
self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
115+
defer.Deferred()
116+
)
116117

117118
def notify(
118119
self,

synapse/storage/persist_events.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ async def add_to_queue(
170170
end_item = queue[-1]
171171
else:
172172
# need to make a new queue item
173-
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
173+
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
174+
defer.Deferred(), consumeErrors=True
175+
)
174176

175177
end_item = _EventPersistQueueItem(
176178
events_and_contexts=[],

synapse/util/async_helpers.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Awaitable,
2424
Callable,
2525
Dict,
26+
Generic,
2627
Hashable,
2728
Iterable,
2829
List,
@@ -39,6 +40,7 @@
3940
from twisted.internet.defer import CancelledError
4041
from twisted.internet.interfaces import IReactorTime
4142
from twisted.python import failure
43+
from twisted.python.failure import Failure
4244

4345
from synapse.logging.context import (
4446
PreserveLoggingContext,
@@ -52,7 +54,7 @@
5254
_T = TypeVar("_T")
5355

5456

55-
class ObservableDeferred:
57+
class ObservableDeferred(Generic[_T]):
5658
"""Wraps a deferred object so that we can add observer deferreds. These
5759
observer deferreds do not affect the callback chain of the original
5860
deferred.
@@ -70,7 +72,7 @@ class ObservableDeferred:
7072

7173
__slots__ = ["_deferred", "_observers", "_result"]
7274

73-
def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
75+
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
7476
object.__setattr__(self, "_deferred", deferred)
7577
object.__setattr__(self, "_result", None)
7678
object.__setattr__(self, "_observers", set())
@@ -115,15 +117,15 @@ def errback(f):
115117

116118
deferred.addCallbacks(callback, errback)
117119

118-
def observe(self) -> defer.Deferred:
120+
def observe(self) -> "defer.Deferred[_T]":
119121
"""Observe the underlying deferred.
120122
121123
This returns a brand new deferred that is resolved when the underlying
122124
deferred is resolved. Interacting with the returned deferred does not
123125
effect the underlying deferred.
124126
"""
125127
if not self._result:
126-
d: "defer.Deferred[Any]" = defer.Deferred()
128+
d: "defer.Deferred[_T]" = defer.Deferred()
127129

128130
def remove(r):
129131
self._observers.discard(d)
@@ -137,7 +139,7 @@ def remove(r):
137139
success, res = self._result
138140
return defer.succeed(res) if success else defer.fail(res)
139141

140-
def observers(self) -> List[defer.Deferred]:
142+
def observers(self) -> "List[defer.Deferred[_T]]":
141143
return self._observers
142144

143145
def has_called(self) -> bool:
@@ -146,7 +148,7 @@ def has_called(self) -> bool:
146148
def has_succeeded(self) -> bool:
147149
return self._result is not None and self._result[0] is True
148150

149-
def get_result(self) -> Any:
151+
def get_result(self) -> Union[_T, Failure]:
150152
return self._result[1]
151153

152154
def __getattr__(self, name: str) -> Any:

0 commit comments

Comments
 (0)