1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import logging
15- from typing import Any , Callable , Dict , List , Optional , Tuple , Type
15+ from typing import Any , Callable , Dict , List , Optional , Tuple
1616
17- from twisted .internet .interfaces import IConsumer , IPullProducer , IReactorTime
1817from twisted .internet .protocol import Protocol
19- from twisted .internet .task import LoopingCall
20- from twisted .web .http import HTTPChannel
2118from twisted .web .resource import Resource
22- from twisted .web .server import Request , Site
2319
2420from synapse .app .generic_worker import GenericWorkerServer
2521from synapse .http .server import JsonResource
3329 ServerReplicationStreamProtocol ,
3430)
3531from synapse .server import HomeServer
36- from synapse .util import Clock
3732
3833from tests import unittest
3934from tests .server import FakeTransport
@@ -154,7 +149,19 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
154149 client_protocol = client_factory .buildProtocol (None )
155150
156151 # Set up the server side protocol
157- channel = _PushHTTPChannel (self .reactor , SynapseRequest , self .site )
152+ channel = self .site .buildProtocol (None )
153+
154+ # hook into the channel's request factory so that we can keep a record
155+ # of the requests
156+ requests : List [SynapseRequest ] = []
157+ real_request_factory = channel .requestFactory
158+
159+ def request_factory (* args , ** kwargs ):
160+ request = real_request_factory (* args , ** kwargs )
161+ requests .append (request )
162+ return request
163+
164+ channel .requestFactory = request_factory
158165
159166 # Connect client to server and vice versa.
160167 client_to_server_transport = FakeTransport (
@@ -176,7 +183,10 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
176183 server_to_client_transport .loseConnection ()
177184 client_to_server_transport .loseConnection ()
178185
179- return channel .request
186+ # there should have been exactly one request
187+ self .assertEqual (len (requests ), 1 )
188+
189+ return requests [0 ]
180190
181191 def assert_request_is_get_repl_stream_updates (
182192 self , request : SynapseRequest , stream_name : str
@@ -387,7 +397,7 @@ def _handle_http_replication_attempt(self, hs, repl_port):
387397 client_protocol = client_factory .buildProtocol (None )
388398
389399 # Set up the server side protocol
390- channel = _PushHTTPChannel ( self .reactor , SynapseRequest , self . _hs_to_site [hs ])
400+ channel = self ._hs_to_site [hs ]. buildProtocol ( None )
391401
392402 # Connect client to server and vice versa.
393403 client_to_server_transport = FakeTransport (
@@ -445,112 +455,6 @@ async def on_rdata(self, stream_name, instance_name, token, rows):
445455 self .received_rdata_rows .append ((stream_name , token , r ))
446456
447457
448- class _PushHTTPChannel (HTTPChannel ):
449- """A HTTPChannel that wraps pull producers to push producers.
450-
451- This is a hack to get around the fact that HTTPChannel transparently wraps a
452- pull producer (which is what Synapse uses to reply to requests) with
453- `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
454- uses the standard reactor rather than letting us use our test reactor, which
455- makes it very hard to test.
456- """
457-
458- def __init__ (
459- self , reactor : IReactorTime , request_factory : Type [Request ], site : Site
460- ):
461- super ().__init__ ()
462- self .reactor = reactor
463- self .requestFactory = request_factory
464- self .site = site
465-
466- self ._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
467-
468- def registerProducer (self , producer , streaming ):
469- # Convert pull producers to push producer.
470- if not streaming :
471- self ._pull_to_push_producer = _PullToPushProducer (
472- self .reactor , producer , self
473- )
474- producer = self ._pull_to_push_producer
475-
476- super ().registerProducer (producer , True )
477-
478- def unregisterProducer (self ):
479- if self ._pull_to_push_producer :
480- # We need to manually stop the _PullToPushProducer.
481- self ._pull_to_push_producer .stop ()
482-
483- def checkPersistence (self , request , version ):
484- """Check whether the connection can be re-used"""
485- # We hijack this to always say no for ease of wiring stuff up in
486- # `handle_http_replication_attempt`.
487- request .responseHeaders .setRawHeaders (b"connection" , [b"close" ])
488- return False
489-
490- def requestDone (self , request ):
491- # Store the request for inspection.
492- self .request = request
493- super ().requestDone (request )
494-
495-
496- class _PullToPushProducer :
497- """A push producer that wraps a pull producer."""
498-
499- def __init__ (
500- self , reactor : IReactorTime , producer : IPullProducer , consumer : IConsumer
501- ):
502- self ._clock = Clock (reactor )
503- self ._producer = producer
504- self ._consumer = consumer
505-
506- # While running we use a looping call with a zero delay to call
507- # resumeProducing on given producer.
508- self ._looping_call = None # type: Optional[LoopingCall]
509-
510- # We start writing next reactor tick.
511- self ._start_loop ()
512-
513- def _start_loop (self ):
514- """Start the looping call to"""
515-
516- if not self ._looping_call :
517- # Start a looping call which runs every tick.
518- self ._looping_call = self ._clock .looping_call (self ._run_once , 0 )
519-
520- def stop (self ):
521- """Stops calling resumeProducing."""
522- if self ._looping_call :
523- self ._looping_call .stop ()
524- self ._looping_call = None
525-
526- def pauseProducing (self ):
527- """Implements IPushProducer"""
528- self .stop ()
529-
530- def resumeProducing (self ):
531- """Implements IPushProducer"""
532- self ._start_loop ()
533-
534- def stopProducing (self ):
535- """Implements IPushProducer"""
536- self .stop ()
537- self ._producer .stopProducing ()
538-
539- def _run_once (self ):
540- """Calls resumeProducing on producer once."""
541-
542- try :
543- self ._producer .resumeProducing ()
544- except Exception :
545- logger .exception ("Failed to call resumeProducing" )
546- try :
547- self ._consumer .unregisterProducer ()
548- except Exception :
549- pass
550-
551- self .stopProducing ()
552-
553-
554458class FakeRedisPubSubServer :
555459 """A fake Redis server for pub/sub."""
556460
0 commit comments