@@ -628,3 +628,149 @@ async def test_multiple_semaphores_different_keys(
628628 ),
629629 ):
630630 ...
631+
632+
633+ async def test_decorator_exception_reraise_with_auto_renewal (
634+ redis_client_sdk : RedisClientSDK ,
635+ semaphore_name : str ,
636+ semaphore_capacity : int ,
637+ short_ttl : datetime .timedelta ,
638+ ):
639+ """Test that user function exceptions are properly re-raised even with auto-renewal running"""
640+
641+ class UserFunctionError (Exception ):
642+ """Custom exception to ensure we're catching the right exception"""
643+
644+ work_started = asyncio .Event ()
645+ renewal_count = 0
646+
647+ # Track that auto-renewal is actually happening
648+ original_renew = __import__ (
649+ "servicelib.redis._semaphore" , fromlist = ["renew_semaphore_entry" ]
650+ ).renew_semaphore_entry
651+
652+ async def tracking_renew_semaphore_entry (semaphore ):
653+ nonlocal renewal_count
654+ renewal_count += 1
655+ await original_renew (semaphore )
656+
657+ with mock .patch (
658+ "servicelib.redis._semaphore.renew_semaphore_entry" ,
659+ side_effect = tracking_renew_semaphore_entry ,
660+ ):
661+
662+ @with_limited_concurrency (
663+ redis_client_sdk ,
664+ key = semaphore_name ,
665+ capacity = semaphore_capacity ,
666+ ttl = short_ttl , # Short TTL to ensure renewal happens
667+ )
668+ async def failing_function ():
669+ work_started .set ()
670+ # Wait long enough for at least one renewal to happen
671+ await asyncio .sleep (short_ttl .total_seconds () * 0.8 )
672+ # Then raise our custom exception
673+ raise UserFunctionError ("User function failed intentionally" )
674+
675+ # Verify the exception is properly re-raised
676+ with pytest .raises (
677+ UserFunctionError , match = "User function failed intentionally"
678+ ):
679+ await failing_function ()
680+
681+ # Ensure work actually started
682+ assert work_started .is_set ()
683+
684+ # Verify auto-renewal was working (at least one renewal should have happened)
685+ assert renewal_count >= 1 , "Auto-renewal should have been called at least once"
686+
687+ # Verify semaphore was properly released by trying to acquire it again
688+ test_semaphore = DistributedSemaphore (
689+ redis_client = redis_client_sdk ,
690+ key = semaphore_name ,
691+ capacity = semaphore_capacity ,
692+ ttl = short_ttl ,
693+ )
694+ assert (
695+ await test_semaphore .get_current_count () == 0
696+ ), "Semaphore should be released after exception"
697+
698+
699+ async def test_decorator_exception_types_preserved (
700+ redis_client_sdk : RedisClientSDK ,
701+ semaphore_name : str ,
702+ semaphore_capacity : int ,
703+ ):
704+ """Test that different exception types are preserved through the decorator"""
705+
706+ # Test ValueError
707+ @with_limited_concurrency (
708+ redis_client_sdk ,
709+ key = f"{ semaphore_name } _ValueError" ,
710+ capacity = semaphore_capacity ,
711+ )
712+ async def function_raising_value_error ():
713+ raise ValueError ("Invalid value" )
714+
715+ with pytest .raises (ValueError ) as exc_info :
716+ await function_raising_value_error ()
717+ assert str (exc_info .value ) == "Invalid value"
718+
719+ # Test TypeError
720+ @with_limited_concurrency (
721+ redis_client_sdk ,
722+ key = f"{ semaphore_name } _TypeError" ,
723+ capacity = semaphore_capacity ,
724+ )
725+ async def function_raising_type_error ():
726+ raise TypeError ("Wrong type" )
727+
728+ with pytest .raises (TypeError ) as exc_info :
729+ await function_raising_type_error ()
730+ assert str (exc_info .value ) == "Wrong type"
731+
732+ # Test KeyError
733+ @with_limited_concurrency (
734+ redis_client_sdk ,
735+ key = f"{ semaphore_name } _KeyError" ,
736+ capacity = semaphore_capacity ,
737+ )
738+ async def function_raising_key_error ():
739+ raise KeyError ("Missing key" )
740+
741+ with pytest .raises (KeyError ) as exc_info :
742+ await function_raising_key_error ()
743+ assert str (exc_info .value ) == "'Missing key'" # KeyError adds quotes
744+
745+ # Test RuntimeError
746+ @with_limited_concurrency (
747+ redis_client_sdk ,
748+ key = f"{ semaphore_name } _RuntimeError" ,
749+ capacity = semaphore_capacity ,
750+ )
751+ async def function_raising_runtime_error ():
752+ raise RuntimeError ("Runtime error" )
753+
754+ with pytest .raises (RuntimeError ) as exc_info :
755+ await function_raising_runtime_error ()
756+ assert str (exc_info .value ) == "Runtime error"
757+
758+
759+ async def test_decorator_cancelled_error_preserved (
760+ redis_client_sdk : RedisClientSDK ,
761+ semaphore_name : str ,
762+ semaphore_capacity : int ,
763+ ):
764+ """Test that CancelledError is properly preserved through the decorator"""
765+
766+ @with_limited_concurrency (
767+ redis_client_sdk ,
768+ key = semaphore_name ,
769+ capacity = semaphore_capacity ,
770+ )
771+ async def function_raising_cancelled_error ():
772+ raise asyncio .CancelledError
773+
774+ # Verify CancelledError is preserved
775+ with pytest .raises (asyncio .CancelledError ):
776+ await function_raising_cancelled_error ()
0 commit comments