@@ -346,17 +346,173 @@ class XTransport(Transport):
346346 Channel = XChannel
347347
348348 conn = Connection (transport = XTransport )
349+ conn .transport .cycle = Mock (name = 'cycle' )
349350 client .ping .side_effect = RuntimeError ()
350351 with pytest .raises (RuntimeError ):
351352 conn .channel ()
352353 pool .disconnect .assert_called_with ()
353354 pool .disconnect .reset_mock ()
355+ # Ensure that the channel without ensured connection to Redis
356+ # won't be added to the cycle.
357+ conn .transport .cycle .add .assert_not_called ()
358+ assert len (conn .transport .channels ) == 0
354359
355360 pool_at_init = [None ]
356361 with pytest .raises (RuntimeError ):
357362 conn .channel ()
358363 pool .disconnect .assert_not_called ()
359364
365+ def test_redis_connection_added_to_cycle_if_ping_succeeds (self ):
366+ """Test should check the connection is added to the cycle only
367+ if the ping to Redis was finished successfully."""
368+ # given: mock pool and client
369+ pool = Mock (name = 'pool' )
370+ client = Mock (name = 'client' )
371+
372+ # override channel class with given mocks
373+ class XChannel (Channel ):
374+ def __init__ (self , * args , ** kwargs ):
375+ self ._pool = pool
376+ super ().__init__ (* args , ** kwargs )
377+
378+ def _get_client (self ):
379+ return lambda * _ , ** __ : client
380+
381+ # override Channel in Transport with given channel
382+ class XTransport (Transport ):
383+ Channel = XChannel
384+
385+ # when: create connection with overridden transport
386+ conn = Connection (transport = XTransport )
387+ conn .transport .cycle = Mock (name = 'cycle' )
388+ # create the channel
389+ chan = conn .channel ()
390+ # then: check if ping was called
391+ client .ping .assert_called_once ()
392+ # the connection was added to the cycle
393+ conn .transport .cycle .add .assert_called_once ()
394+ assert len (conn .transport .channels ) == 1
395+ # the channel was flaged as registered into poller
396+ assert chan ._registered
397+
398+ def test_redis_on_disconnect_channel_only_if_was_registered (self ):
399+ """Test shoud check if the _on_disconnect method is called only
400+ if the channel was registered into the poller."""
401+ # given: mock pool and client
402+ pool = Mock (name = 'pool' )
403+ client = Mock (
404+ name = 'client' ,
405+ ping = Mock (return_value = True )
406+ )
407+
408+ # create RedisConnectionMock class
409+ # for the possibility to run disconnect method
410+ class RedisConnectionMock :
411+ def disconnect (self , * args ):
412+ pass
413+
414+ # override Channel method with given mocks
415+ class XChannel (Channel ):
416+ connection_class = RedisConnectionMock
417+
418+ def __init__ (self , * args , ** kwargs ):
419+ self ._pool = pool
420+ # counter to check if the method was called
421+ self .on_disconect_count = 0
422+ super ().__init__ (* args , ** kwargs )
423+
424+ def _get_client (self ):
425+ return lambda * _ , ** __ : client
426+
427+ def _on_connection_disconnect (self , connection ):
428+ # increment the counter when the method is called
429+ self .on_disconect_count += 1
430+
431+ # create the channel
432+ chan = XChannel (Mock (
433+ _used_channel_ids = [],
434+ channel_max = 1 ,
435+ channels = [],
436+ client = Mock (
437+ transport_options = {},
438+ hostname = "127.0.0.1" ,
439+ virtual_host = None )))
440+ # create the _connparams with overriden connection_class
441+ connparams = chan ._connparams (asynchronous = True )
442+ # create redis.Connection
443+ redis_connection = connparams ['connection_class' ]()
444+ # the connection was added to the cycle
445+ chan .connection .cycle .add .assert_called_once ()
446+ # and the ping was called
447+ client .ping .assert_called_once ()
448+ # the channel was registered
449+ assert chan ._registered
450+ # than disconnect the Redis connection
451+ redis_connection .disconnect ()
452+ # the on_disconnect counter should be incremented
453+ assert chan .on_disconect_count == 1
454+
455+ def test_redis__on_disconnect_should_not_be_called_if_not_registered (self ):
456+ """Test should check if the _on_disconnect method is not called because
457+ the connection to Redis isn't established properly."""
458+ # given: mock pool
459+ pool = Mock (name = 'pool' )
460+ # client mock with ping method which return ConnectionError
461+ from redis .exceptions import ConnectionError
462+ client = Mock (
463+ name = 'client' ,
464+ ping = Mock (side_effect = ConnectionError ())
465+ )
466+
467+ # create RedisConnectionMock
468+ # for the possibility to run disconnect method
469+ class RedisConnectionMock :
470+ def disconnect (self , * args ):
471+ pass
472+
473+ # override Channel method with given mocks
474+ class XChannel (Channel ):
475+ connection_class = RedisConnectionMock
476+
477+ def __init__ (self , * args , ** kwargs ):
478+ self ._pool = pool
479+ # counter to check if the method was called
480+ self .on_disconect_count = 0
481+ super ().__init__ (* args , ** kwargs )
482+
483+ def _get_client (self ):
484+ return lambda * _ , ** __ : client
485+
486+ def _on_connection_disconnect (self , connection ):
487+ # increment the counter when the method is called
488+ self .on_disconect_count += 1
489+
490+ # then: exception was risen
491+ with pytest .raises (ConnectionError ):
492+ # when: create the channel
493+ chan = XChannel (Mock (
494+ _used_channel_ids = [],
495+ channel_max = 1 ,
496+ channels = [],
497+ client = Mock (
498+ transport_options = {},
499+ hostname = "127.0.0.1" ,
500+ virtual_host = None )))
501+ # create the _connparams with overriden connection_class
502+ connparams = chan ._connparams (asynchronous = True )
503+ # create redis.Connection
504+ redis_connection = connparams ['connection_class' ]()
505+ # the connection wasn't added to the cycle
506+ chan .connection .cycle .add .assert_not_called ()
507+ # the ping was called once with the exception
508+ client .ping .assert_called_once ()
509+ # the channel was not registered
510+ assert not chan ._registered
511+ # then: disconnect the Redis connection
512+ redis_connection .disconnect ()
513+ # the on_disconnect counter shouldn't be incremented
514+ assert chan .on_disconect_count == 0
515+
360516 def test_get_redis_ConnectionError (self ):
361517 from redis .exceptions import ConnectionError
362518
0 commit comments