@@ -138,35 +138,40 @@ class ApiSecuritySamplerTest extends DDSpecification {
138138 ! sampled
139139 }
140140
141- void ' sampleRequest honors expiration' () {
141+ void ' preSampleRequest honors expiration' () {
142142 given :
143- def ctx = createContext(' route1' , ' GET' , 200 )
144- ctx . setApiSecurityEndpointHash( 42L )
145- ctx . setKeepOpenForApiSecurityPostProcessing( true )
143+ def ctx1 = createContext(' route1' , ' GET' , 200 )
144+ def ctx2 = createContext( ' route1 ' , ' GET ' , 200 )
145+ def ctx3 = createContext( ' route1 ' , ' GET ' , 200 )
146146 final timeSource = new ControllableTimeSource ()
147147 timeSource. set(0 )
148148 final long expirationTimeInMs = 10L
149149 final long expirationTimeInNs = expirationTimeInMs * 1_000_000
150150 def sampler = new ApiSecuritySamplerImpl (10 , expirationTimeInMs, timeSource)
151151
152- when :
153- def sampled = sampler. sampleRequest(ctx)
152+ when : ' first request samples'
153+ def preSampled1 = sampler. preSampleRequest(ctx1)
154+ def sampled1 = sampler. sampleRequest(ctx1)
154155
155156 then :
156- sampled
157+ preSampled1
158+ sampled1
157159
158- when :
159- sampled = sampler. sampleRequest(ctx )
160+ when : ' second request to same endpoint before expiration '
161+ def preSampled2 = sampler. preSampleRequest(ctx2 )
160162
161163 then : ' second request is not sampled'
162- ! sampled
164+ ! preSampled2
163165
164166 when : ' expiration time has passed'
167+ sampler. releaseOne()
165168 timeSource. advance(expirationTimeInNs)
166- sampled = sampler. sampleRequest(ctx)
169+ def preSampled3 = sampler. preSampleRequest(ctx3)
170+ def sampled3 = sampler. sampleRequest(ctx3)
167171
168172 then : ' request is sampled again'
169- sampled
173+ preSampled3
174+ sampled3
170175 }
171176
172177 void ' internal accessMap never goes beyond capacity' () {
@@ -198,10 +203,13 @@ class ApiSecuritySamplerTest extends DDSpecification {
198203
199204 expect :
200205 for (int i = 0 ; i < maxCapacity * 10 ; i++ ) {
201- final ctx = createContext(' route1' , ' GET' , 200 + 1 )
202- ctx. setApiSecurityEndpointHash(i as long )
203- ctx. setKeepOpenForApiSecurityPostProcessing(true )
204- assert sampler. sampleRequest(ctx)
206+ final ctx = createContext(' route1' , ' GET' , 200 + i)
207+ def preSampled = sampler. preSampleRequest(ctx)
208+ // First request always samples, then we advance time so each subsequent request expires
209+ assert preSampled
210+ def sampled = sampler. sampleRequest(ctx)
211+ assert sampled
212+ sampler. releaseOne()
205213 assert sampler. accessMap. size() <= 2
206214 if (i % 2 ) {
207215 timeSource. advance(expirationTimeInMs * 1_000_000)
0 commit comments