2424using Microsoft . Extensions . DependencyInjection ;
2525using Microsoft . Extensions . Hosting ;
2626using Microsoft . Extensions . Logging ;
27+ using Microsoft . VisualStudio . TestPlatform ;
2728using Moq ;
29+ using Xunit . Sdk ;
2830
2931namespace Microsoft . AspNetCore . Server . Kestrel . Core . Tests ;
3032
@@ -50,6 +52,30 @@ public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List<byte[]
5052 public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments ( int id , List < byte [ ] > packets )
5153 => RunTlsClientHelloCallbackTest_WithMultipleSegments ( id , packets , tlsClientHelloCallbackExpected : false ) ;
5254
55+ [ Fact ]
56+ public async Task RunTlsClientHelloCallbackTest_WithExtraShortLastingToken ( )
57+ {
58+ var serviceContext = new TestServiceContext ( ) ;
59+
60+ var pipe = new Pipe ( ) ;
61+ var writer = pipe . Writer ;
62+ var reader = new ObservablePipeReader ( pipe . Reader ) ;
63+
64+ var transport = new DuplexPipe ( reader , writer ) ;
65+ var transportConnection = new DefaultConnectionContext ( "test" , transport , transport ) ;
66+
67+ var tlsClientHelloCallbackInvoked = false ;
68+ var listener = new TlsListener ( ( ctx , data ) => { tlsClientHelloCallbackInvoked = true ; } ) ;
69+
70+ var cts = new CancellationTokenSource ( TimeSpan . FromMilliseconds ( 3 ) ) ;
71+
72+ await writer . WriteAsync ( new byte [ 1 ] { 0x16 } ) ;
73+ await VerifyThrowsAnyAsync (
74+ async ( ) => await listener . OnTlsClientHelloAsync ( transportConnection , cts . Token ) ,
75+ typeof ( OperationCanceledException ) , typeof ( TaskCanceledException ) ) ;
76+ Assert . False ( tlsClientHelloCallbackInvoked ) ;
77+ }
78+
5379 [ Fact ]
5480 public async Task RunTlsClientHelloCallbackTest_WithPreCanceledToken ( )
5581 {
@@ -69,10 +95,9 @@ public async Task RunTlsClientHelloCallbackTest_WithPreCanceledToken()
6995 cts . Cancel ( ) ;
7096
7197 await writer . WriteAsync ( new byte [ 1 ] { 0x16 } ) ;
72- await Assert . ThrowsAsync < OperationCanceledException > ( async ( ) =>
73- {
74- await listener . OnTlsClientHelloAsync ( transportConnection , cts . Token ) ;
75- } ) ;
98+ await VerifyThrowsAnyAsync (
99+ async ( ) => await listener . OnTlsClientHelloAsync ( transportConnection , cts . Token ) ,
100+ typeof ( OperationCanceledException ) , typeof ( TaskCanceledException ) ) ;
76101 Assert . False ( tlsClientHelloCallbackInvoked ) ;
77102 }
78103
@@ -598,4 +623,28 @@ public static IEnumerable<object[]> InvalidClientHelloData_Segmented()
598623 _invalidTlsClientHelloHeader , _invalid3BytesMessage , _invalid9BytesMessage ,
599624 _invalidUnknownProtocolVersion1 , _invalidUnknownProtocolVersion2 , _invalidIncorrectHandshakeMessageType
600625 } ;
626+
627+ static async Task VerifyThrowsAnyAsync ( Func < Task > code , params Type [ ] exceptionTypes )
628+ {
629+ if ( exceptionTypes == null || exceptionTypes . Length == 0 )
630+ {
631+ throw new ArgumentException ( "At least one exception type must be provided." , nameof ( exceptionTypes ) ) ;
632+ }
633+
634+ try
635+ {
636+ await code ( ) ;
637+ }
638+ catch ( Exception ex )
639+ {
640+ if ( exceptionTypes . Any ( type => type . IsInstanceOfType ( ex ) ) )
641+ {
642+ return ;
643+ }
644+
645+ throw ThrowsException . ForIncorrectExceptionType ( exceptionTypes . First ( ) , ex ) ;
646+ }
647+
648+ throw ThrowsException . ForNoException ( exceptionTypes . First ( ) ) ;
649+ }
601650}
0 commit comments