Skip to content

Commit 82cb175

Browse files
Copilotphrocker
andcommitted
Fix private memory storage and retrieval with null userId and enable UI search (#284)
* Initial plan * Fix private memory storage and retrieval with null userId Co-authored-by: phrocker <[email protected]> * Fix potential NPE when both markings and userId are null Co-authored-by: phrocker <[email protected]> * Enable PRIVATE memory search from UI by passing agentId to hybridSearch Co-authored-by: phrocker <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: phrocker <[email protected]>
1 parent 4c068ca commit 82cb175

File tree

6 files changed

+133
-26
lines changed

6 files changed

+133
-26
lines changed

api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentMemoryController.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,13 @@ public ResponseEntity<List<AgentMemoryDTO>> hybridSearch(
426426

427427
String searchTerm = (String) searchRequest.get("searchTerm");
428428
String markingsFilter = (String) searchRequest.get("markings");
429+
String agentId = (String) searchRequest.get("agentId");
429430
Integer limit = (Integer) searchRequest.getOrDefault("limit", 10);
430431
// Default to 0 to let the service determine optimal threshold
431432
Double threshold = (Double) searchRequest.getOrDefault("threshold", 0.0);
432433

433-
log.debug("Hybrid search - term: {}, markings: {}, limit: {}, threshold: {}",
434-
searchTerm, markingsFilter, limit, threshold);
434+
log.debug("Hybrid search - term: {}, markings: {}, agentId: {}, limit: {}, threshold: {}",
435+
searchTerm, markingsFilter, agentId, limit, threshold);
435436

436437
try {
437438
var operatingUser = getOperatingUser(request,response);
@@ -449,7 +450,7 @@ public ResponseEntity<List<AgentMemoryDTO>> hybridSearch(
449450
AccessEvaluator evaluator = authorizations.isEmpty() ? null : AccessEvaluator.of(authorizations);
450451

451452
List<AgentMemory> results = vectorMemoryStore.hybridSearch(evaluator,
452-
searchTerm, markingsFilter, userId, limit, threshold);
453+
searchTerm, markingsFilter, userId, agentId, limit, threshold);
453454

454455
List<AgentMemoryDTO> responseDTOs = results.stream()
455456
.map(this::convertToDTO)
@@ -595,8 +596,9 @@ public ResponseEntity<Page<AgentMemoryDTO>> searchAgentMemory(
595596
AccessEvaluator evaluator = authorizations.isEmpty() ? null : AccessEvaluator.of(authorizations);
596597

597598
// Pass 0 as threshold to let the service determine the optimal threshold based on query
599+
// Pass agent parameter to enable access to PRIVATE memories without USER markings
598600
List<AgentMemory> results = vectorMemoryStore.hybridSearch(evaluator,
599-
content, markings, operatingUser.getUserId(), 10, 0);
601+
content, markings, operatingUser.getUserId(), agent, 10, 0);
600602

601603
Page<AgentMemoryDTO> responseDTOs = results.stream()
602604
.map(this::convertToDTO).collect(Collectors.collectingAndThen(

dataplane/src/main/java/io/sentrius/sso/core/services/agents/MemoryAccessControlService.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ public boolean canAccessMemory(AgentMemory memory, AccessEvaluator evaluator, St
8484
log.debug("Creator access granted");
8585
return true;
8686
}
87+
88+
// PRIVATE memories without USER markings: allow agent to access its own memories
89+
// This handles the case where userId is null during storage but the memory belongs to the agent
90+
if ("PRIVATE".equalsIgnoreCase(memory.getClassification()) &&
91+
agentId != null && agentId.equals(memory.getAgentId()) &&
92+
(memory.getMarkings() == null || !memory.getMarkings().contains("USER:"))) {
93+
log.debug("PRIVATE agent memory access granted - agent accessing its own memory: {}", agentId);
94+
return true;
95+
}
8796

8897
// Check if memory can be shared with the agent
8998
if (agentId != null && memory.canBeSharedWith(agentId)) {

dataplane/src/main/java/io/sentrius/sso/core/services/agents/VectorAgentMemoryStore.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,10 @@ public List<AgentMemory> findSimilarMemoriesForAgent(String queryText, String ag
173173
*/
174174
public List<AgentMemory> hybridSearch(
175175
AccessEvaluator evaluator, String searchTerm, String markingsFilter,
176-
String requestingUserId, int limit, double threshold) {
176+
String requestingUserId, String agentId, int limit, double threshold) {
177177

178-
log.info("Hybrid search - term: '{}', markings: {}, user: {}, threshold: {}",
179-
searchTerm, markingsFilter, requestingUserId, threshold);
178+
log.info("Hybrid search - term: '{}', markings: {}, user: {}, agent: {}, threshold: {}",
179+
searchTerm, markingsFilter, requestingUserId, agentId, threshold);
180180

181181
try {
182182
// Use a more lenient threshold if not explicitly provided
@@ -249,7 +249,7 @@ public List<AgentMemory> hybridSearch(
249249
.filter(m -> seen.add(m.getId())) // dedupe by ID
250250
.filter(m -> !m.isExpired())
251251
.filter(m -> !isExcludedMemoryKey(m.getMemoryKey())) // Exclude temporary lookup results
252-
.filter(m -> accessControlService.canAccessMemory(m, evaluator, requestingUserId, null, "READ"))
252+
.filter(m -> accessControlService.canAccessMemory(m, evaluator, requestingUserId, agentId, "READ"))
253253
.sorted((a, b) -> Double.compare(
254254
scores.getOrDefault(b.getId(), 0.0),
255255
scores.getOrDefault(a.getId(), 0.0)))

dataplane/src/test/java/io/sentrius/sso/core/services/agents/MemoryUserMarkingAccessControlTest.java

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,75 @@ void testCanAccessMemory_UserMarkingWithWhitespace_ShouldHandleCorrectly() {
248248
assertTrue(canAccess, "User should be able to access memory with USER marking even with whitespace");
249249
}
250250

251+
@Test
252+
void testCanAccessMemory_PrivateAgentMemoryWithoutUserMarking_ShouldAllowAgentAccess() {
253+
// Arrange
254+
String agentId = "agent-123";
255+
String userId = null; // No userId available (e.g., during agent initialization)
256+
AgentMemory memory = AgentMemory.builder()
257+
.memoryKey("test-memory")
258+
.memoryValue("test-value")
259+
.agentId(agentId)
260+
.classification("PRIVATE")
261+
.markings("CONVERSATION") // No USER marking
262+
.creatorUserId(userId)
263+
.build();
264+
265+
// Act
266+
boolean canAccess = accessControlService.canAccessMemory(memory, userId, agentId, "READ");
267+
268+
// Assert
269+
assertTrue(canAccess, "Agent should be able to access its own PRIVATE memory without USER marking");
270+
}
271+
272+
@Test
273+
void testCanAccessMemory_PrivateAgentMemoryWithoutUserMarking_ShouldDenyDifferentAgent() {
274+
// Arrange
275+
String agentId = "agent-123";
276+
String differentAgentId = "agent-456";
277+
String userId = null;
278+
AgentMemory memory = AgentMemory.builder()
279+
.memoryKey("test-memory")
280+
.memoryValue("test-value")
281+
.agentId(agentId)
282+
.classification("PRIVATE")
283+
.markings("CONVERSATION")
284+
.creatorUserId(userId)
285+
.build();
286+
287+
// Mock empty policies (since we check policies after agent check fails)
288+
when(policyRepository.findByIsActiveTrueOrderByPolicyName()).thenReturn(Collections.emptyList());
289+
290+
// Act
291+
boolean canAccess = accessControlService.canAccessMemory(memory, userId, differentAgentId, "READ");
292+
293+
// Assert
294+
assertFalse(canAccess, "Different agent should NOT be able to access another agent's PRIVATE memory");
295+
}
296+
297+
@Test
298+
void testCanAccessMemory_PrivateMemoryWithUserMarking_ShouldNotUseAgentFallback() {
299+
// Arrange
300+
String agentId = "agent-123";
301+
String userId = "user-123";
302+
String differentUserId = "user-456";
303+
AgentMemory memory = AgentMemory.builder()
304+
.memoryKey("test-memory")
305+
.memoryValue("test-value")
306+
.agentId(agentId)
307+
.classification("PRIVATE")
308+
.markings("CONVERSATION,USER:" + userId) // Has USER marking
309+
.creatorUserId(userId)
310+
.build();
311+
312+
// Act
313+
boolean canAccess = accessControlService.canAccessMemory(memory, differentUserId, agentId, "READ");
314+
315+
// Assert
316+
assertFalse(canAccess, "Agent access fallback should not apply when USER marking is present");
317+
}
318+
319+
251320
@Test
252321
void testCanAccessMemory_MultipleUserMarkings_ShouldDenyIfNoMatch() {
253322
// Arrange - This is an edge case that shouldn't normally happen but we should handle it

enterprise-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,20 @@ public void onApplicationEvent(final ApplicationReadyEvent event) {
357357
JsonNode value = memoryMeta.has("value") ?
358358
memoryMeta.get("value") : memoryMeta;
359359

360-
// Add userId to markings for privacy scoping
361-
String userId = agentExecution.getUser().getUserId();
362-
String enhancedMarkings = markings != null
363-
? markings + ",USER:" + userId
364-
: "USER:" + userId;
360+
// Add userId to markings for privacy scoping if userId is available
361+
String userId = agentExecution.getUser() != null
362+
? agentExecution.getUser().getUserId()
363+
: null;
364+
String enhancedMarkings;
365+
if (userId != null && !userId.isEmpty()) {
366+
enhancedMarkings = markings != null
367+
? markings + ",USER:" + userId
368+
: "USER:" + userId;
369+
} else {
370+
// If no userId, use markings as-is without USER scoping
371+
// Ensure we have at least an empty string to avoid NPE in split()
372+
enhancedMarkings = markings != null ? markings : "";
373+
}
365374

366375
agentClientService.storeMemory(agentExecution,
367376
agentExecutionContext.getAgentContext().getName(),
@@ -370,7 +379,7 @@ public void onApplicationEvent(final ApplicationReadyEvent event) {
370379
.memoryKey(memoryEntry.getKey())
371380
.memoryValue(value.toString())
372381
.classification(classification)
373-
.markings(enhancedMarkings.split(","))
382+
.markings(enhancedMarkings.isEmpty() ? new String[0] : enhancedMarkings.split(","))
374383
.conversationId(agentExecution.getCommunicationId())
375384
.build());
376385
log.info("Stored memory: {} with classification: {} and markings: {}",

enterprise-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -520,11 +520,20 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
520520
JsonNode value = memoryMeta.has("value") ?
521521
memoryMeta.get("value") : memoryMeta;
522522

523-
// Add userId to markings for privacy scoping
524-
String userId = chatAgent.getAgentExecution().getUser().getUserId();
525-
String enhancedMarkings = markings != null
526-
? markings + ",USER:" + userId
527-
: "USER:" + userId;
523+
// Add userId to markings for privacy scoping if userId is available
524+
String userId = chatAgent.getAgentExecution().getUser() != null
525+
? chatAgent.getAgentExecution().getUser().getUserId()
526+
: null;
527+
String enhancedMarkings;
528+
if (userId != null && !userId.isEmpty()) {
529+
enhancedMarkings = markings != null
530+
? markings + ",USER:" + userId
531+
: "USER:" + userId;
532+
} else {
533+
// If no userId, use markings as-is without USER scoping
534+
// Ensure we have at least an empty string to avoid NPE in split()
535+
enhancedMarkings = markings != null ? markings : "";
536+
}
528537

529538
agentClientService.storeMemory(chatAgent.getAgentExecution(),
530539
websocketCommunication.getAgentExecutionContextDTO().getAgentContext().getName(),
@@ -533,7 +542,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
533542
.memoryKey(memoryEntry.getKey())
534543
.memoryValue(value.toString())
535544
.classification(classification)
536-
.markings(enhancedMarkings.split(","))
545+
.markings(enhancedMarkings.isEmpty() ? new String[0] : enhancedMarkings.split(","))
537546
.conversationId(chatAgent.getAgentExecution().getCommunicationId())
538547
.build());
539548
log.info("Stored memory: {} with classification: {} and markings: {}",
@@ -565,11 +574,20 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
565574
JsonNode value = memoryMeta.has("value") ?
566575
memoryMeta.get("value") : memoryMeta;
567576

568-
// Add userId to markings for privacy scoping
569-
String userId = chatAgent.getAgentExecution().getUser().getUserId();
570-
String enhancedMarkings = markings != null
571-
? markings + ",USER:" + userId
572-
: "USER:" + userId;
577+
// Add userId to markings for privacy scoping if userId is available
578+
String userId = chatAgent.getAgentExecution().getUser() != null
579+
? chatAgent.getAgentExecution().getUser().getUserId()
580+
: null;
581+
String enhancedMarkings;
582+
if (userId != null && !userId.isEmpty()) {
583+
enhancedMarkings = markings != null
584+
? markings + ",USER:" + userId
585+
: "USER:" + userId;
586+
} else {
587+
// If no userId, use markings as-is without USER scoping
588+
// Ensure we have at least an empty string to avoid NPE in split()
589+
enhancedMarkings = markings != null ? markings : "";
590+
}
573591

574592
agentClientService.storeMemory(chatAgent.getAgentExecution(),
575593
websocketCommunication.getAgentExecutionContextDTO().getAgentContext().getName(),
@@ -578,7 +596,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
578596
.memoryKey(memoryEntry.getKey())
579597
.memoryValue(value.toString())
580598
.classification(classification)
581-
.markings(enhancedMarkings.split(","))
599+
.markings(enhancedMarkings.isEmpty() ? new String[0] : enhancedMarkings.split(","))
582600
.conversationId(chatAgent.getAgentExecution().getCommunicationId())
583601
.build());
584602
log.info("Stored memory: {} with classification: {} and markings: {}",

0 commit comments

Comments
 (0)