1111using Microsoft . Extensions . Logging ;
1212using System . Globalization ;
1313using System . Net ;
14+ using System . Text ;
1415using System . Text . Json ;
1516using System . Text . RegularExpressions ;
16- using Unobtanium . Web . Proxy . Http ;
17- using Unobtanium . Web . Proxy . Models ;
1817
1918namespace DevProxy . Plugins . Behavior ;
2019
2120public sealed class RetryAfterPlugin (
2221 ILogger < RetryAfterPlugin > logger ,
23- ISet < UrlToWatch > urlsToWatch ) : BasePlugin ( logger , urlsToWatch )
22+ ISet < UrlToWatch > urlsToWatch ,
23+ IProxyStorage proxyStorage ) : BasePlugin ( logger , urlsToWatch )
2424{
25+ private readonly IProxyStorage _proxyStorage = proxyStorage ;
2526 public static readonly string ThrottledRequestsKey = "ThrottledRequests" ;
2627
2728 public override string Name => nameof ( RetryAfterPlugin ) ;
2829
29- public override Task BeforeRequestAsync ( ProxyRequestArgs e , CancellationToken cancellationToken )
30+ public override Func < RequestArguments , CancellationToken , Task < PluginResponse > > ? OnRequestAsync => ( args , cancellationToken ) =>
3031 {
31- Logger . LogTrace ( "{Method} called" , nameof ( BeforeRequestAsync ) ) ;
32+ Logger . LogTrace ( "{Method} called" , nameof ( OnRequestAsync ) ) ;
3233
33- ArgumentNullException . ThrowIfNull ( e ) ;
34-
35- if ( ! e . HasRequestUrlMatch ( UrlsToWatch ) )
34+ if ( ! ProxyUtils . MatchesUrlToWatch ( UrlsToWatch , args . Request . RequestUri ) )
3635 {
37- Logger . LogRequest ( "URL not matched" , MessageType . Skipped , new ( e . Session ) ) ;
38- return Task . CompletedTask ;
36+ Logger . LogRequest ( "URL not matched" , MessageType . Skipped , args . Request ) ;
37+ return Task . FromResult ( PluginResponse . Continue ( ) ) ;
3938 }
40- if ( e . ResponseState . HasBeenSet )
39+
40+ if ( args . Request . Method == HttpMethod . Options )
4141 {
42- Logger . LogRequest ( "Response already set " , MessageType . Skipped , new ( e . Session ) ) ;
43- return Task . CompletedTask ;
42+ Logger . LogRequest ( "Skipping OPTIONS request " , MessageType . Skipped , args . Request ) ;
43+ return Task . FromResult ( PluginResponse . Continue ( ) ) ;
4444 }
45- if ( string . Equals ( e . Session . HttpClient . Request . Method , "OPTIONS" , StringComparison . OrdinalIgnoreCase ) )
45+
46+ var throttleResponse = CheckIfThrottled ( args . Request ) ;
47+ if ( throttleResponse != null )
4648 {
47- Logger . LogRequest ( "Skipping OPTIONS request" , MessageType . Skipped , new ( e . Session ) ) ;
48- return Task . CompletedTask ;
49+ return Task . FromResult ( PluginResponse . Respond ( throttleResponse ) ) ;
4950 }
5051
51- ThrottleIfNecessary ( e ) ;
52-
53- Logger . LogTrace ( "Left {Name}" , nameof ( BeforeRequestAsync ) ) ;
54- return Task . CompletedTask ;
55- }
52+ Logger . LogTrace ( "Left {Name}" , nameof ( OnRequestAsync ) ) ;
53+ return Task . FromResult ( PluginResponse . Continue ( ) ) ;
54+ } ;
5655
57- private void ThrottleIfNecessary ( ProxyRequestArgs e )
56+ private HttpResponseMessage ? CheckIfThrottled ( HttpRequestMessage request )
5857 {
59- var request = e . Session . HttpClient . Request ;
60- if ( ! e . GlobalData . TryGetValue ( ThrottledRequestsKey , out var value ) )
58+ if ( ! _proxyStorage . GlobalData . TryGetValue ( ThrottledRequestsKey , out var value ) )
6159 {
62- Logger . LogRequest ( "Request not throttled" , MessageType . Skipped , new ( e . Session ) ) ;
63- return ;
60+ Logger . LogRequest ( "Request not throttled" , MessageType . Skipped , request ) ;
61+ return null ;
6462 }
6563
6664 if ( value is not List < ThrottlerInfo > throttledRequests )
6765 {
68- Logger . LogRequest ( "Request not throttled" , MessageType . Skipped , new ( e . Session ) ) ;
69- return ;
66+ Logger . LogRequest ( "Request not throttled" , MessageType . Skipped , request ) ;
67+ return null ;
7068 }
7169
7270 var expiredThrottlers = throttledRequests . Where ( t => t . ResetTime < DateTime . Now ) . ToArray ( ) ;
@@ -77,32 +75,31 @@ private void ThrottleIfNecessary(ProxyRequestArgs e)
7775
7876 if ( throttledRequests . Count == 0 )
7977 {
80- Logger . LogRequest ( "Request not throttled" , MessageType . Skipped , new ( e . Session ) ) ;
81- return ;
78+ Logger . LogRequest ( "Request not throttled" , MessageType . Skipped , request ) ;
79+ return null ;
8280 }
8381
8482 foreach ( var throttler in throttledRequests )
8583 {
8684 var throttleInfo = throttler . ShouldThrottle ( request , throttler . ThrottlingKey ) ;
8785 if ( throttleInfo . ThrottleForSeconds > 0 )
8886 {
89- var message = $ "Calling { request . Url } before waiting for the Retry-After period. Request will be throttled. Throttling on { throttler . ThrottlingKey } .";
90- Logger . LogRequest ( message , MessageType . Failed , new ( e . Session ) ) ;
87+ var message = $ "Calling { request . RequestUri } before waiting for the Retry-After period. Request will be throttled. Throttling on { throttler . ThrottlingKey } .";
88+ Logger . LogRequest ( message , MessageType . Failed , request ) ;
9189
9290 throttler . ResetTime = DateTime . Now . AddSeconds ( throttleInfo . ThrottleForSeconds ) ;
93- UpdateProxyResponse ( e , throttleInfo , string . Join ( ' ' , message ) ) ;
94- return ;
91+ return BuildThrottleResponse ( request , throttleInfo , string . Join ( ' ' , message ) ) ;
9592 }
9693 }
9794
98- Logger . LogRequest ( "Request not throttled" , MessageType . Skipped , new ( e . Session ) ) ;
95+ Logger . LogRequest ( "Request not throttled" , MessageType . Skipped , request ) ;
96+ return null ;
9997 }
10098
101- private static void UpdateProxyResponse ( ProxyRequestArgs e , ThrottlingInfo throttlingInfo , string message )
99+ private static HttpResponseMessage BuildThrottleResponse ( HttpRequestMessage request , ThrottlingInfo throttlingInfo , string message )
102100 {
103101 var headers = new List < MockResponseHeader > ( ) ;
104102 var body = string . Empty ;
105- var request = e . Session . HttpClient . Request ;
106103
107104 // override the response body and headers for the error response
108105 if ( ProxyUtils . IsGraphRequest ( request ) )
@@ -128,7 +125,7 @@ private static void UpdateProxyResponse(ProxyRequestArgs e, ThrottlingInfo throt
128125 else
129126 {
130127 // ProxyUtils.BuildGraphResponseHeaders already includes CORS headers
131- if ( request . Headers . Any ( h => h . Name . Equals ( "Origin" , StringComparison . OrdinalIgnoreCase ) ) )
128+ if ( request . Headers . TryGetValues ( "Origin" , out var _ ) )
132129 {
133130 headers . Add ( new ( "Access-Control-Allow-Origin" , "*" ) ) ;
134131 headers . Add ( new ( "Access-Control-Expose-Headers" , throttlingInfo . RetryAfterHeaderName ) ) ;
@@ -137,9 +134,18 @@ private static void UpdateProxyResponse(ProxyRequestArgs e, ThrottlingInfo throt
137134
138135 headers . Add ( new ( throttlingInfo . RetryAfterHeaderName , throttlingInfo . ThrottleForSeconds . ToString ( CultureInfo . InvariantCulture ) ) ) ;
139136
140- e . Session . GenericResponse ( body ?? string . Empty , HttpStatusCode . TooManyRequests , headers . Select ( h => new HttpHeader ( h . Name , h . Value ) ) ) ;
141- e . ResponseState . HasBeenSet = true ;
137+ var response = new HttpResponseMessage ( HttpStatusCode . TooManyRequests )
138+ {
139+ Content = new StringContent ( body ?? string . Empty , Encoding . UTF8 , "application/json" )
140+ } ;
141+
142+ foreach ( var header in headers )
143+ {
144+ _ = response . Headers . TryAddWithoutValidation ( header . Name , header . Value ) ;
145+ }
146+
147+ return response ;
142148 }
143149
144- private static string BuildApiErrorMessage ( Request r , string message ) => $ "{ message } { ( ProxyUtils . IsGraphRequest ( r ) ? ProxyUtils . IsSdkRequest ( r ) ? "" : string . Join ( ' ' , MessageUtils . BuildUseSdkForErrorsMessage ( ) ) : "" ) } ";
150+ private static string BuildApiErrorMessage ( HttpRequestMessage r , string message ) => $ "{ message } { ( ProxyUtils . IsGraphRequest ( r ) ? ProxyUtils . IsSdkRequest ( r ) ? "" : string . Join ( ' ' , MessageUtils . BuildUseSdkForErrorsMessage ( ) ) : "" ) } ";
145151}
0 commit comments