Skip to content

Commit 406ff86

Browse files
Attack wave samples (#360)
1 parent f06b49b commit 406ff86

File tree

12 files changed

+311
-20
lines changed

12 files changed

+311
-20
lines changed

lib/agent/aikido_types/init_data.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ type AttackWaveState struct {
8989
IpQueues map[string]*SlidingWindow
9090
// Map of IP addresses to the last time an event was sent for that IP
9191
LastSent map[string]int64
92+
// Maximum number of samples to keep per IP, can not be higher than attackWaveThreshold
93+
MaxSamplesPerIP int
9294
}
9395

9496
type ListsConfigData struct {
@@ -238,11 +240,12 @@ func NewServerData() *ServerData {
238240
Packages: make(map[string]Package),
239241
PollingData: NewServerDataPolling(),
240242
AttackWave: AttackWaveState{
241-
Threshold: 15, // Default: 15 requests
242-
WindowSize: 1, // Default: 1 minute
243-
MinBetween: 20 * 60 * 1000, // Default: 20 minutes
244-
IpQueues: make(map[string]*SlidingWindow),
245-
LastSent: make(map[string]int64),
243+
Threshold: 15, // Default: 15 requests
244+
WindowSize: 1, // Default: 1 minute
245+
MinBetween: 20 * 60 * 1000, // Default: 20 minutes
246+
IpQueues: make(map[string]*SlidingWindow),
247+
LastSent: make(map[string]int64),
248+
MaxSamplesPerIP: 15,
246249
},
247250
}
248251
}

lib/agent/aikido_types/sliding_window.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
package aikido_types
22

3+
type SuspiciousRequest struct {
4+
Method string `json:"method"`
5+
Url string `json:"url"`
6+
}
7+
38
// SlidingWindow represents a time-based sliding window counter.
49
// It maintains a queue of counts per time bucket and a running total.
510
type SlidingWindow struct {
6-
Total int // Running total of all counts in the window
7-
Queue Queue[int] // Queue of counts per time bucket
11+
Total int // Running total of all counts in the window
12+
Queue Queue[int] // Queue of counts per time bucket
13+
Samples []SuspiciousRequest // Sample requests collected for attack wave detection (max MaxSamplesPerIP)
814
}
915

1016
// NewSlidingWindow creates a new sliding window with the specified size.
@@ -40,6 +46,25 @@ func (sw *SlidingWindow) Increment() {
4046
sw.Total++
4147
}
4248

49+
// AddSample adds a sample request to the sliding window for attack wave detection.
50+
// It maintains a maximum of MaxSamplesPerIP unique samples (based on method and URL).
51+
func (sw *SlidingWindow) AddSample(method, url string, maxSamplesPerIP int) {
52+
// Check if this sample already exists
53+
for _, sample := range sw.Samples {
54+
if sample.Method == method && sample.Url == url {
55+
return // Already exists, skip
56+
}
57+
}
58+
59+
// Add the sample if we haven't reached the limit
60+
if len(sw.Samples) < maxSamplesPerIP {
61+
sw.Samples = append(sw.Samples, SuspiciousRequest{
62+
Method: method,
63+
Url: url,
64+
})
65+
}
66+
}
67+
4368
// IsEmpty returns true if the total count is zero.
4469
func (sw *SlidingWindow) IsEmpty() bool {
4570
return sw.Total == 0

lib/agent/aikido_types/sliding_window_test.go

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,222 @@ func TestSlidingWindowIntegration(t *testing.T) {
266266
assert.NotContains(t, windowMap, "endpoint3")
267267
})
268268
}
269+
270+
func TestAddSample(t *testing.T) {
271+
t.Run("adds sample to empty sliding window", func(t *testing.T) {
272+
sw := NewSlidingWindow()
273+
sw.AddSample("GET", "/api/users", 15)
274+
275+
assert.Equal(t, 1, len(sw.Samples))
276+
assert.Equal(t, "GET", sw.Samples[0].Method)
277+
assert.Equal(t, "/api/users", sw.Samples[0].Url)
278+
})
279+
280+
t.Run("adds multiple different samples", func(t *testing.T) {
281+
sw := NewSlidingWindow()
282+
sw.AddSample("GET", "/api/users", 15)
283+
sw.AddSample("POST", "/api/login", 15)
284+
sw.AddSample("DELETE", "/api/users/123", 15)
285+
286+
assert.Equal(t, 3, len(sw.Samples))
287+
assert.Equal(t, "GET", sw.Samples[0].Method)
288+
assert.Equal(t, "/api/users", sw.Samples[0].Url)
289+
assert.Equal(t, "POST", sw.Samples[1].Method)
290+
assert.Equal(t, "/api/login", sw.Samples[1].Url)
291+
assert.Equal(t, "DELETE", sw.Samples[2].Method)
292+
assert.Equal(t, "/api/users/123", sw.Samples[2].Url)
293+
})
294+
295+
t.Run("prevents duplicate samples with same method and URL", func(t *testing.T) {
296+
sw := NewSlidingWindow()
297+
sw.AddSample("GET", "/api/users", 15)
298+
sw.AddSample("GET", "/api/users", 15) // duplicate
299+
sw.AddSample("GET", "/api/users", 15) // duplicate
300+
301+
assert.Equal(t, 1, len(sw.Samples))
302+
assert.Equal(t, "GET", sw.Samples[0].Method)
303+
assert.Equal(t, "/api/users", sw.Samples[0].Url)
304+
})
305+
306+
t.Run("allows same URL with different methods", func(t *testing.T) {
307+
sw := NewSlidingWindow()
308+
sw.AddSample("GET", "/api/users", 15)
309+
sw.AddSample("POST", "/api/users", 15)
310+
sw.AddSample("DELETE", "/api/users", 15)
311+
312+
assert.Equal(t, 3, len(sw.Samples))
313+
assert.Equal(t, "GET", sw.Samples[0].Method)
314+
assert.Equal(t, "POST", sw.Samples[1].Method)
315+
assert.Equal(t, "DELETE", sw.Samples[2].Method)
316+
})
317+
318+
t.Run("allows same method with different URLs", func(t *testing.T) {
319+
sw := NewSlidingWindow()
320+
sw.AddSample("GET", "/api/users", 15)
321+
sw.AddSample("GET", "/api/posts", 15)
322+
sw.AddSample("GET", "/api/comments", 15)
323+
324+
assert.Equal(t, 3, len(sw.Samples))
325+
assert.Equal(t, "/api/users", sw.Samples[0].Url)
326+
assert.Equal(t, "/api/posts", sw.Samples[1].Url)
327+
assert.Equal(t, "/api/comments", sw.Samples[2].Url)
328+
})
329+
330+
t.Run("enforces maximum of 10 samples", func(t *testing.T) {
331+
sw := NewSlidingWindow()
332+
333+
// Add 12 unique samples
334+
for i := 0; i < 12; i++ {
335+
sw.AddSample("GET", "/api/endpoint"+string(rune('0'+i)), 10)
336+
}
337+
338+
assert.Equal(t, 10, len(sw.Samples))
339+
})
340+
341+
t.Run("does not add 11th sample even if unique", func(t *testing.T) {
342+
sw := NewSlidingWindow()
343+
344+
// Add exactly 10 samples
345+
for i := 0; i < 10; i++ {
346+
sw.AddSample("GET", "/api/endpoint"+string(rune('0'+i)), 10)
347+
}
348+
assert.Equal(t, 10, len(sw.Samples))
349+
350+
// Try to add an 11th unique sample
351+
sw.AddSample("POST", "/api/new-endpoint", 10)
352+
assert.Equal(t, 10, len(sw.Samples))
353+
354+
// Verify the 11th sample was not added
355+
found := false
356+
for _, sample := range sw.Samples {
357+
if sample.Method == "POST" && sample.Url == "/api/new-endpoint" {
358+
found = true
359+
break
360+
}
361+
}
362+
assert.False(t, found)
363+
})
364+
365+
t.Run("duplicates do not count toward 10 sample limit", func(t *testing.T) {
366+
sw := NewSlidingWindow()
367+
368+
// Add 5 unique samples
369+
for i := 0; i < 5; i++ {
370+
sw.AddSample("GET", "/api/endpoint"+string(rune('0'+i)), 15)
371+
}
372+
373+
// Try to add duplicates
374+
sw.AddSample("GET", "/api/endpoint0", 15) // duplicate
375+
sw.AddSample("GET", "/api/endpoint1", 15) // duplicate
376+
sw.AddSample("GET", "/api/endpoint2", 15) // duplicate
377+
378+
assert.Equal(t, 5, len(sw.Samples))
379+
380+
// Add 5 more unique samples to reach 10
381+
for i := 5; i < 10; i++ {
382+
sw.AddSample("GET", "/api/endpoint"+string(rune('0'+i)), 15)
383+
}
384+
assert.Equal(t, 10, len(sw.Samples))
385+
386+
// Try to add more duplicates - should still be 10
387+
sw.AddSample("GET", "/api/endpoint5", 15)
388+
sw.AddSample("GET", "/api/endpoint9", 15)
389+
assert.Equal(t, 10, len(sw.Samples))
390+
})
391+
392+
t.Run("preserves samples during window operations", func(t *testing.T) {
393+
sw := NewSlidingWindow()
394+
sw.AddSample("GET", "/api/users", 15)
395+
sw.AddSample("POST", "/api/login", 15)
396+
sw.Increment()
397+
398+
// Advance the window
399+
sw.Advance(5)
400+
401+
// Samples should still be present
402+
assert.Equal(t, 2, len(sw.Samples))
403+
assert.Equal(t, "GET", sw.Samples[0].Method)
404+
assert.Equal(t, "/api/users", sw.Samples[0].Url)
405+
assert.Equal(t, "POST", sw.Samples[1].Method)
406+
assert.Equal(t, "/api/login", sw.Samples[1].Url)
407+
})
408+
409+
t.Run("empty method and URL are valid samples", func(t *testing.T) {
410+
sw := NewSlidingWindow()
411+
sw.AddSample("", "", 15)
412+
sw.AddSample("GET", "", 15)
413+
sw.AddSample("", "/api/users", 15)
414+
415+
assert.Equal(t, 3, len(sw.Samples))
416+
})
417+
}
418+
419+
func TestSlidingWindowSamplesIntegration(t *testing.T) {
420+
t.Run("simulates attack wave detection with samples", func(t *testing.T) {
421+
// Create a map simulating per-IP tracking
422+
ipMap := map[string]*SlidingWindow{
423+
"192.168.1.100": NewSlidingWindow(),
424+
}
425+
426+
ip := "192.168.1.100"
427+
sw := ipMap[ip]
428+
429+
// Simulate suspicious requests
430+
requests := []struct {
431+
method string
432+
url string
433+
}{
434+
{"GET", "/admin"},
435+
{"GET", "/admin"}, // duplicate
436+
{"POST", "/admin"},
437+
{"GET", "/wp-admin"},
438+
{"GET", "/.env"},
439+
{"GET", "/config.php"},
440+
{"POST", "/login"},
441+
{"GET", "/admin"}, // duplicate
442+
}
443+
444+
for _, req := range requests {
445+
sw.Increment()
446+
sw.AddSample(req.method, req.url, 15)
447+
}
448+
449+
// Should have 6 unique samples (2 duplicates removed)
450+
assert.Equal(t, 6, len(sw.Samples))
451+
assert.Equal(t, 8, sw.Total) // But total count should be 8
452+
453+
// Verify samples are unique
454+
uniqueCheck := make(map[string]bool)
455+
for _, sample := range sw.Samples {
456+
key := sample.Method + ":" + sample.Url
457+
assert.False(t, uniqueCheck[key], "Found duplicate sample: "+key)
458+
uniqueCheck[key] = true
459+
}
460+
})
461+
462+
t.Run("samples persist across window advances until removal", func(t *testing.T) {
463+
windowMap := map[string]*SlidingWindow{
464+
"10.0.0.1": NewSlidingWindow(),
465+
}
466+
467+
sw := windowMap["10.0.0.1"]
468+
sw.AddSample("GET", "/api/v1/users", 15)
469+
sw.AddSample("POST", "/api/v1/login", 15)
470+
sw.Increment()
471+
sw.Increment()
472+
473+
// Advance window multiple times
474+
AdvanceSlidingWindowMap(windowMap, 3)
475+
AdvanceSlidingWindowMap(windowMap, 3)
476+
477+
// Samples should still be there
478+
assert.Contains(t, windowMap, "10.0.0.1")
479+
assert.Equal(t, 2, len(windowMap["10.0.0.1"].Samples))
480+
481+
// Advance until window is empty
482+
AdvanceSlidingWindowMap(windowMap, 3)
483+
484+
// Window should be removed when total reaches 0
485+
assert.NotContains(t, windowMap, "10.0.0.1")
486+
})
487+
}

lib/agent/grpc/request.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,12 @@ func isRateLimitingThresholdExceeded(config *RateLimitingConfig, countsMap map[s
201201

202202
// updateAttackWaveCountsAndDetect implements the attack wave detection logic:
203203
// 1. Validates the request is from a web scanner and has a valid IP address
204-
// 2. Increments the sliding window counter for this IP
204+
// 2. Increments the sliding window counter for this IP and collects request samples
205205
// 3. Applies throttling: if an event was recently sent for this IP (within minBetween window),
206206
// returns early without checking threshold or sending another event
207207
// 4. Checks if the total count within the sliding window exceeds the threshold
208-
// 5. If threshold exceeded: records the event time on the queue, logs the detection, and sends event to cloud
209-
func updateAttackWaveCountsAndDetect(server *ServerData, isWebScanner bool, ip string, userId string, userAgent string) bool {
208+
// 5. If threshold exceeded: records the event time on the queue, logs the detection, and sends event with samples to cloud
209+
func updateAttackWaveCountsAndDetect(server *ServerData, isWebScanner bool, ip string, userId string, userAgent string, method string, url string) bool {
210210
if !isWebScanner || ip == "" {
211211
return false
212212
}
@@ -224,6 +224,11 @@ func updateAttackWaveCountsAndDetect(server *ServerData, isWebScanner bool, ip s
224224
return false
225225
}
226226

227+
// Add this request as a sample
228+
if queue != nil {
229+
queue.AddSample(method, url, server.AttackWave.MaxSamplesPerIP)
230+
}
231+
227232
// check threshold within window
228233
if queue == nil || queue.Total < server.AttackWave.Threshold {
229234
return false // threshold not reached
@@ -234,11 +239,18 @@ func updateAttackWaveCountsAndDetect(server *ServerData, isWebScanner bool, ip s
234239
if server.Logger != nil {
235240
log.Infof(server.Logger, "Attack wave detected from IP: %s", ip)
236241
}
242+
237243
// report event to cloud
238244
cloud.SendAttackDetectedEvent(server, &protos.AttackDetected{
239245
Token: server.AikidoConfig.Token,
240246
Request: &protos.Request{IpAddress: ip, UserAgent: userAgent},
241-
Attack: &protos.Attack{Metadata: []*protos.Metadata{}, UserId: userId},
247+
Attack: &protos.Attack{
248+
Metadata: []*protos.Metadata{{
249+
Key: "samples",
250+
Value: utils.JsonMarshal(queue.Samples),
251+
}},
252+
UserId: userId,
253+
},
242254
}, "detected_attack_wave")
243255
return true
244256
}

lib/agent/grpc/request_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func TestAttackWaveThrottling(t *testing.T) {
3737
server.AttackWave.IpQueues[ip] = sw
3838

3939
// Should return false (throttled) because last event was only 30 seconds ago (< 60s MinBetween)
40-
assert.False(t, updateAttackWaveCountsAndDetect(server, true, ip, "", ""))
40+
assert.False(t, updateAttackWaveCountsAndDetect(server, true, ip, "", "", "", ""))
4141
})
4242

4343
t.Run("returns true and populates LastSent map when IP reaches threshold for first time", func(t *testing.T) {
@@ -66,7 +66,7 @@ func TestAttackWaveThrottling(t *testing.T) {
6666
assert.False(t, exists, "IP should not be in LastSent map before threshold is reached")
6767

6868
// Should return true (event sent) because this is the first time reaching threshold
69-
assert.True(t, updateAttackWaveCountsAndDetect(server, true, ip, "", ""))
69+
assert.True(t, updateAttackWaveCountsAndDetect(server, true, ip, "", "", "", ""))
7070

7171
// Verify LastSent map was populated
7272
assert.True(t, server.AttackWave.LastSent[ip] > 0, "LastSent should be set after event is sent")

lib/agent/grpc/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (s *GrpcServer) OnRequestShutdown(ctx context.Context, req *protos.RequestM
8282
go storeRoute(server, req.GetMethod(), req.GetRouteParsed(), req.GetApiSpec(), req.GetRateLimited())
8383
go updateRateLimitingCounts(server, req.GetMethod(), req.GetRoute(), req.GetRouteParsed(), req.GetUser(), req.GetIp(), req.GetRateLimitGroup())
8484
}
85-
go updateAttackWaveCountsAndDetect(server, req.GetIsWebScanner(), req.GetIp(), req.GetUser(), req.GetUserAgent())
85+
go updateAttackWaveCountsAndDetect(server, req.GetIsWebScanner(), req.GetIp(), req.GetUser(), req.GetUserAgent(), req.GetMethod(), req.GetUrl())
8686

8787
atomic.StoreUint32(&server.GotTraffic, 1)
8888
return &emptypb.Empty{}, nil

lib/agent/utils/utils.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package utils
22

33
import (
4+
"encoding/json"
45
"fmt"
56
. "main/aikido_types"
67
"main/config"
@@ -146,3 +147,11 @@ func AnonymizeToken(token string) string {
146147
}
147148
return token[len(token)-4:]
148149
}
150+
151+
func JsonMarshal(v any) string {
152+
bytes, err := json.Marshal(v)
153+
if err != nil {
154+
return ""
155+
}
156+
return string(bytes)
157+
}

0 commit comments

Comments
 (0)