Skip to content

Commit 769d922

Browse files
Merge pull request #112 from AikidoSec/AIK-4387
AIK-4387 Make sure route hit reporting works for Java
2 parents 1a90bd8 + 480d1ee commit 769d922

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

agent_api/src/main/java/dev/aikido/agent_api/storage/Hostnames.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ public void add(String hostname, int port) {
2424
}
2525
public void addArray(HostnameEntry[] hostnameEntries) {
2626
for (HostnameEntry entry: hostnameEntries) {
27-
add(entry.getHostname(), entry.getPort());
27+
String key = getKey(entry.getHostname(), entry.getPort());
28+
if (map.containsKey(key)) {
29+
// Merge hits :
30+
map.get(key).incrementHits(entry.getHits());
31+
} else {
32+
map.put(key, entry);
33+
}
2834
}
2935
}
3036
public HostnameEntry[] asArray() {
@@ -53,6 +59,10 @@ public HostnameEntry(String hostname, int port) {
5359
public void incrementHits() {
5460
hits++;
5561
}
62+
public void incrementHits(int delta) {
63+
hits += delta;
64+
}
65+
5666

5767
public String getHostname() {
5868
return hostname;

agent_api/src/test/java/storage/HostnamesTest.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,78 @@ public void testExceedMaxEntriesWithMultiplePorts() {
105105
assertTrue(containsEntry(entries, "test.com", 8080));
106106
assertTrue(containsEntry(entries, "newsite.com", 3000));
107107
}
108+
@Test
109+
public void testAddArrayWithNewEntries() {
110+
Hostnames.HostnameEntry[] entriesToAdd = {
111+
new Hostnames.HostnameEntry("example.com", 80),
112+
new Hostnames.HostnameEntry("test.com", 443)
113+
};
114+
entriesToAdd[0].incrementHits();
115+
entriesToAdd[1].incrementHits();
116+
117+
hostnames.addArray(entriesToAdd);
118+
119+
Hostnames.HostnameEntry[] entries = hostnames.asArray();
120+
assertEquals(2, entries.length);
121+
assertTrue(containsEntry(entries, "example.com", 80));
122+
assertTrue(containsEntry(entries, "test.com", 443));
123+
}
124+
125+
@Test
126+
public void testAddArrayWithExistingEntries() {
127+
hostnames.add("example.com", 80); // Initial entry
128+
Hostnames.HostnameEntry[] entriesToAdd = {
129+
new Hostnames.HostnameEntry("example.com", 80), // Same entry, should merge hits
130+
new Hostnames.HostnameEntry("test.com", 443) // New entry
131+
};
132+
entriesToAdd[0].incrementHits();
133+
entriesToAdd[1].incrementHits();
134+
135+
hostnames.addArray(entriesToAdd);
136+
137+
Hostnames.HostnameEntry[] entries = hostnames.asArray();
138+
assertEquals(2, entries.length);
139+
assertEquals(2, getHits(entries, "example.com", 80)); // Hits should be 2
140+
assertEquals(1, getHits(entries, "test.com", 443)); // Hits should be 1
141+
}
142+
143+
@Test
144+
public void testAddArrayExceedMaxEntries() {
145+
hostnames.add("example.com", 80);
146+
hostnames.add("test.com", 443);
147+
hostnames.add("localhost", 3000);
148+
149+
Hostnames.HostnameEntry[] entriesToAdd = {
150+
new Hostnames.HostnameEntry("newsite.com", 8080), // This should cause an eviction
151+
new Hostnames.HostnameEntry("example.com", 80) // Should merge hits
152+
};
153+
entriesToAdd[0].incrementHits();
154+
entriesToAdd[1].incrementHits(10);
155+
156+
hostnames.addArray(entriesToAdd);
157+
158+
Hostnames.HostnameEntry[] entries = hostnames.asArray();
159+
assertEquals(3, entries.length);
160+
assertFalse(containsEntry(entries, "test.com", 443)); // "test.com" should be evicted
161+
assertTrue(containsEntry(entries, "localhost", 3000));
162+
assertTrue(containsEntry(entries, "newsite.com", 8080));
163+
assertEquals(10, getHits(entries, "example.com", 80)); // Hits should be 2
164+
}
165+
166+
@Test
167+
public void testAddArrayWithZeroPort() {
168+
Hostnames.HostnameEntry[] entriesToAdd = {
169+
new Hostnames.HostnameEntry("example.com", 0),
170+
new Hostnames.HostnameEntry("test.com", 0)
171+
};
172+
173+
hostnames.addArray(entriesToAdd);
174+
175+
Hostnames.HostnameEntry[] entries = hostnames.asArray();
176+
assertEquals(2, entries.length);
177+
assertTrue(containsEntry(entries, "example.com", 0));
178+
assertTrue(containsEntry(entries, "test.com", 0));
179+
}
108180

109181
private boolean containsEntry(Hostnames.HostnameEntry[] entries, String hostname, int port) {
110182
for (Hostnames.HostnameEntry entry : entries) {
@@ -114,4 +186,12 @@ private boolean containsEntry(Hostnames.HostnameEntry[] entries, String hostname
114186
}
115187
return false;
116188
}
189+
private int getHits(Hostnames.HostnameEntry[] entries, String hostname, int port) {
190+
for (Hostnames.HostnameEntry entry : entries) {
191+
if (entry.getHostname().equals(hostname) && entry.getPort() == port) {
192+
return entry.getHits();
193+
}
194+
}
195+
return 0; // Not found
196+
}
117197
}

0 commit comments

Comments
 (0)