Skip to content

Commit 30eee22

Browse files
Copilotphrocker
andauthored
Fix trust scores showing static value by persisting policy violation events (#171)
* Initial plan * Add PolicyViolationEvent persistence for accurate trust score calculation - Created PolicyViolationEvent entity to persist policy violation events - Created PolicyViolationEventType enum for violation event types - Created PolicyViolationEventRepository for data access - Created PolicyViolationEventService for managing violation events - Modified ZeroTrustAccessTokenService to record violations on approve/deny - Modified TrustEvaluationService to use persistent incident counts - Added comprehensive tests for PolicyViolationEventService Co-authored-by: phrocker <[email protected]> * Address code review feedback: use constructor injection - Convert field injection to constructor injection in ZeroTrustAccessTokenService - Convert field injection to constructor injection in TrustEvaluationService - All tests pass Co-authored-by: phrocker <[email protected]> * Add Flyway migration V39 for policy_violation_events table Creates policy_violation_events table with columns matching the PolicyViolationEvent JPA entity and appropriate indexes for efficient querying by entity_id, timestamp, and approval status. Co-authored-by: phrocker <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: phrocker <[email protected]>
1 parent 25f596b commit 30eee22

File tree

8 files changed

+839
-16
lines changed

8 files changed

+839
-16
lines changed

analytics/src/main/java/io/sentrius/agent/analysis/agents/trust/TrustEvaluationService.java

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import io.sentrius.sso.core.services.ATPLPolicyService;
1313
import io.sentrius.sso.core.services.UserService;
1414
import io.sentrius.sso.core.services.trust.AgentTrustScoreService;
15+
import io.sentrius.sso.core.services.trust.PolicyViolationEventService;
1516
import io.sentrius.sso.core.trust.*;
1617
import io.sentrius.sso.provenance.ProvenanceEvent;
1718
import lombok.extern.slf4j.Slf4j;
@@ -37,32 +38,33 @@ public class TrustEvaluationService {
3738
private final AgentTrustScoreService trustScoreService;
3839
private final ATPLPolicyService atplPolicyService;
3940
private final UserService userService;
40-
41-
@Autowired(required = false)
42-
private LLMGuidedSchedulerService llmScheduler;
41+
private final LLMGuidedSchedulerService llmScheduler;
42+
private final io.sentrius.sso.core.services.feedback.RLHFFeedbackService rlhfFeedbackService;
43+
private final PolicyViolationEventService policyViolationEventService;
4344

4445
private final Map<String, List<ProvenanceEvent>> provenanceCache = new ConcurrentHashMap<>();
4546
private final Map<String, Integer> incidentTracker = new ConcurrentHashMap<>();
4647

47-
// Optional RLHF service - may not be available in all contexts
48-
private final io.sentrius.sso.core.services.feedback.RLHFFeedbackService rlhfFeedbackService;
49-
48+
@Autowired
5049
public TrustEvaluationService(
5150
AgentHeartbeatRepository heartbeatRepository,
5251
AgentCommunicationRepository communicationRepository,
5352
SessionLogRepository sessionLogRepository,
5453
AgentTrustScoreService trustScoreService,
5554
ATPLPolicyService atplPolicyService,
5655
UserService userService,
57-
@org.springframework.beans.factory.annotation.Autowired(required = false)
58-
io.sentrius.sso.core.services.feedback.RLHFFeedbackService rlhfFeedbackService) {
56+
@Autowired(required = false) LLMGuidedSchedulerService llmScheduler,
57+
@Autowired(required = false) io.sentrius.sso.core.services.feedback.RLHFFeedbackService rlhfFeedbackService,
58+
@Autowired(required = false) PolicyViolationEventService policyViolationEventService) {
5959
this.heartbeatRepository = heartbeatRepository;
6060
this.communicationRepository = communicationRepository;
6161
this.sessionLogRepository = sessionLogRepository;
6262
this.trustScoreService = trustScoreService;
6363
this.atplPolicyService = atplPolicyService;
6464
this.userService = userService;
65+
this.llmScheduler = llmScheduler;
6566
this.rlhfFeedbackService = rlhfFeedbackService;
67+
this.policyViolationEventService = policyViolationEventService;
6668
}
6769

6870
@Scheduled(fixedRate = 300000, initialDelay = 60000)
@@ -193,7 +195,7 @@ private AgentContext buildHumanUserContext(String userId, String username) {
193195
List<SessionLog> userSessions = sessionLogRepository.findByUsername(username);
194196

195197
int priorRuns = calculatePriorSessions(userSessions);
196-
int incidentCount = incidentTracker.getOrDefault(userId, 0);
198+
int incidentCount = getIncidentCount(userId);
197199

198200
// Human users are verified through Keycloak authentication
199201
List<ProvenanceEvent> events = provenanceCache.getOrDefault(userId, Collections.emptyList());
@@ -236,7 +238,7 @@ private Set<String> extractUserTags(String username) {
236238
private AgentContext buildAgentContext(String agentId, String agentName) {
237239
Optional<AgentHeartbeat> heartbeatOpt = heartbeatRepository.findByAgentId(agentId);
238240
int priorRuns = calculatePriorRuns(agentId);
239-
int incidentCount = incidentTracker.getOrDefault(agentId, 0);
241+
int incidentCount = getIncidentCount(agentId);
240242

241243
boolean enclaveVerified = heartbeatOpt
242244
.map(hb -> hb.getStatus() != null && hb.getStatus().contains("verified"))
@@ -263,6 +265,19 @@ private AgentContext buildAgentContext(String agentId, String agentName) {
263265
.build();
264266
}
265267

268+
/**
269+
* Get the incident count for an entity from the persistent store if available,
270+
* otherwise fall back to the in-memory tracker.
271+
*/
272+
private int getIncidentCount(String entityId) {
273+
// First try to get from persistent store (policy violation events)
274+
if (policyViolationEventService != null) {
275+
return policyViolationEventService.getIncidentCount(entityId);
276+
}
277+
// Fall back to in-memory tracker
278+
return incidentTracker.getOrDefault(entityId, 0);
279+
}
280+
266281
private int calculatePriorRuns(String agentId) {
267282
LocalDateTime thirtyDaysAgo = LocalDateTime.now().minusDays(30);
268283
return (int) heartbeatRepository.findAll().stream()
@@ -317,14 +332,56 @@ public void cacheProvenanceEvent(ProvenanceEvent event) {
317332
}
318333
}
319334

335+
/**
336+
* Record an incident for an entity. This is now primarily for legacy/manual incident tracking.
337+
* For policy violations, use the PolicyViolationEventService directly.
338+
*/
320339
public void recordIncident(String agentId) {
321340
incidentTracker.merge(agentId, 1, Integer::sum);
322-
log.warn("Incident recorded for agent: {}. Total incidents: {}",
341+
log.warn("Incident recorded for agent: {}. Total in-memory incidents: {}",
323342
agentId, incidentTracker.get(agentId));
324343
}
325344

345+
/**
346+
* Record a policy violation incident that will affect trust scores.
347+
* This persists the violation to the database for accurate trust score calculation.
348+
*/
349+
public void recordPolicyViolation(String entityId, String entityName, String endpoint, boolean approved, String approverId) {
350+
if (policyViolationEventService != null) {
351+
if (approved) {
352+
policyViolationEventService.recordZtatApproval(
353+
entityId, entityName, endpoint, null, approverId, null,
354+
"Policy violation recorded via TrustEvaluationService"
355+
);
356+
} else {
357+
policyViolationEventService.recordZtatDenial(
358+
entityId, entityName, endpoint, null, approverId, null,
359+
"Policy violation recorded via TrustEvaluationService"
360+
);
361+
}
362+
log.info("Policy violation recorded for entity {}: endpoint={}, approved={}",
363+
entityId, endpoint, approved);
364+
} else {
365+
// Fall back to in-memory tracking if persistent service is not available
366+
if (!approved) {
367+
recordIncident(entityId);
368+
}
369+
}
370+
}
371+
372+
/**
373+
* Clear incidents for an entity. Note: this only clears the in-memory tracker.
374+
* Persistent policy violations cannot be cleared (they are part of the audit trail).
375+
*/
326376
public void clearIncidents(String agentId) {
327377
incidentTracker.put(agentId, 0);
328-
log.info("Incidents cleared for agent: {}", agentId);
378+
log.info("In-memory incidents cleared for agent: {}", agentId);
379+
}
380+
381+
/**
382+
* Get the total incident count for an entity (from both persistent and in-memory stores).
383+
*/
384+
public int getTotalIncidentCount(String entityId) {
385+
return getIncidentCount(entityId);
329386
}
330387
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
-- Create policy_violation_events table for tracking policy violation events
2+
-- These events are used to calculate behavior scores in trust evaluations
3+
CREATE TABLE IF NOT EXISTS policy_violation_events (
4+
id BIGSERIAL PRIMARY KEY,
5+
entity_id VARCHAR(255) NOT NULL,
6+
entity_name VARCHAR(255),
7+
event_type VARCHAR(50) NOT NULL,
8+
approved BOOLEAN NOT NULL,
9+
endpoint VARCHAR(1024),
10+
policy_id VARCHAR(255),
11+
approver_id VARCHAR(255),
12+
ztat_request_id BIGINT,
13+
description TEXT,
14+
timestamp TIMESTAMP NOT NULL
15+
);
16+
17+
-- Create indexes for better query performance
18+
CREATE INDEX IF NOT EXISTS idx_pv_entity_id_timestamp
19+
ON policy_violation_events(entity_id, timestamp DESC);
20+
CREATE INDEX IF NOT EXISTS idx_pv_timestamp
21+
ON policy_violation_events(timestamp DESC);
22+
CREATE INDEX IF NOT EXISTS idx_pv_entity_id
23+
ON policy_violation_events(entity_id);
24+
CREATE INDEX IF NOT EXISTS idx_pv_approved
25+
ON policy_violation_events(approved);
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package io.sentrius.sso.core.model.trust;
2+
3+
import java.time.LocalDateTime;
4+
import jakarta.persistence.*;
5+
import lombok.*;
6+
7+
/**
8+
* Records policy violation events when an agent/user tries to access an endpoint
9+
* outside their policy, and the resulting approval/denial decision.
10+
* These events are used to calculate behavior scores in trust evaluations.
11+
*/
12+
@Entity
13+
@Table(name = "policy_violation_events", indexes = {
14+
@Index(name = "idx_pv_entity_id_timestamp", columnList = "entity_id,timestamp"),
15+
@Index(name = "idx_pv_timestamp", columnList = "timestamp")
16+
})
17+
@Getter
18+
@Setter
19+
@Builder
20+
@NoArgsConstructor
21+
@AllArgsConstructor
22+
public class PolicyViolationEvent {
23+
24+
@Id
25+
@GeneratedValue(strategy = GenerationType.IDENTITY)
26+
private Long id;
27+
28+
/**
29+
* The ID of the entity (agent or user) that attempted the policy violation
30+
*/
31+
@Column(name = "entity_id", nullable = false)
32+
private String entityId;
33+
34+
/**
35+
* The name of the entity
36+
*/
37+
@Column(name = "entity_name")
38+
private String entityName;
39+
40+
/**
41+
* The type of violation event
42+
*/
43+
@Enumerated(EnumType.STRING)
44+
@Column(name = "event_type", nullable = false, length = 50)
45+
private PolicyViolationEventType eventType;
46+
47+
/**
48+
* Whether the violation was approved by a supervisor
49+
*/
50+
@Column(name = "approved", nullable = false)
51+
private Boolean approved;
52+
53+
/**
54+
* The endpoint that was accessed outside the policy
55+
*/
56+
@Column(name = "endpoint")
57+
private String endpoint;
58+
59+
/**
60+
* The policy ID that was violated
61+
*/
62+
@Column(name = "policy_id")
63+
private String policyId;
64+
65+
/**
66+
* The ID of the user who approved/denied the violation
67+
*/
68+
@Column(name = "approver_id")
69+
private String approverId;
70+
71+
/**
72+
* The ZTAT request ID associated with this event
73+
*/
74+
@Column(name = "ztat_request_id")
75+
private Long ztatRequestId;
76+
77+
/**
78+
* Description or notes about the violation
79+
*/
80+
@Column(name = "description", columnDefinition = "TEXT")
81+
private String description;
82+
83+
/**
84+
* When this event occurred
85+
*/
86+
@Column(name = "timestamp", nullable = false)
87+
private LocalDateTime timestamp;
88+
89+
@PrePersist
90+
protected void onCreate() {
91+
if (timestamp == null) {
92+
timestamp = LocalDateTime.now();
93+
}
94+
}
95+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package io.sentrius.sso.core.model.trust;
2+
3+
/**
4+
* Types of policy violation events that can be recorded.
5+
*/
6+
public enum PolicyViolationEventType {
7+
/**
8+
* Agent/user accessed an endpoint outside their policy and was approved
9+
*/
10+
OUT_OF_POLICY_ACCESS_APPROVED,
11+
12+
/**
13+
* Agent/user accessed an endpoint outside their policy and was denied
14+
*/
15+
OUT_OF_POLICY_ACCESS_DENIED,
16+
17+
/**
18+
* A ZTAT (Zero Trust Access Token) request was approved
19+
*/
20+
ZTAT_REQUEST_APPROVED,
21+
22+
/**
23+
* A ZTAT request was denied
24+
*/
25+
ZTAT_REQUEST_DENIED,
26+
27+
/**
28+
* An OPS JIT request was approved
29+
*/
30+
OPS_JIT_APPROVED,
31+
32+
/**
33+
* An OPS JIT request was denied
34+
*/
35+
OPS_JIT_DENIED
36+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package io.sentrius.sso.core.repository.trust;
2+
3+
import io.sentrius.sso.core.model.trust.PolicyViolationEvent;
4+
import io.sentrius.sso.core.model.trust.PolicyViolationEventType;
5+
import org.springframework.data.jpa.repository.JpaRepository;
6+
import org.springframework.data.jpa.repository.Query;
7+
import org.springframework.data.repository.query.Param;
8+
import org.springframework.stereotype.Repository;
9+
10+
import java.time.LocalDateTime;
11+
import java.util.List;
12+
13+
@Repository
14+
public interface PolicyViolationEventRepository extends JpaRepository<PolicyViolationEvent, Long> {
15+
16+
/**
17+
* Find all policy violation events for an entity ordered by timestamp descending
18+
*/
19+
List<PolicyViolationEvent> findByEntityIdOrderByTimestampDesc(String entityId);
20+
21+
/**
22+
* Find policy violation events for an entity within a time range
23+
*/
24+
List<PolicyViolationEvent> findByEntityIdAndTimestampBetweenOrderByTimestampDesc(
25+
String entityId, LocalDateTime start, LocalDateTime end);
26+
27+
/**
28+
* Count denied violations (incidents) for an entity since a given time
29+
*/
30+
@Query("SELECT COUNT(e) FROM PolicyViolationEvent e WHERE e.entityId = :entityId " +
31+
"AND e.approved = false AND e.timestamp >= :since")
32+
long countDeniedViolations(@Param("entityId") String entityId, @Param("since") LocalDateTime since);
33+
34+
/**
35+
* Count approved violations for an entity since a given time
36+
*/
37+
@Query("SELECT COUNT(e) FROM PolicyViolationEvent e WHERE e.entityId = :entityId " +
38+
"AND e.approved = true AND e.timestamp >= :since")
39+
long countApprovedViolations(@Param("entityId") String entityId, @Param("since") LocalDateTime since);
40+
41+
/**
42+
* Count all violations (both approved and denied) for an entity since a given time
43+
*/
44+
@Query("SELECT COUNT(e) FROM PolicyViolationEvent e WHERE e.entityId = :entityId " +
45+
"AND e.timestamp >= :since")
46+
long countAllViolations(@Param("entityId") String entityId, @Param("since") LocalDateTime since);
47+
48+
/**
49+
* Find violations by event type for an entity
50+
*/
51+
List<PolicyViolationEvent> findByEntityIdAndEventTypeOrderByTimestampDesc(
52+
String entityId, PolicyViolationEventType eventType);
53+
54+
/**
55+
* Find recent violations across all entities
56+
*/
57+
@Query("SELECT e FROM PolicyViolationEvent e WHERE e.timestamp >= :since ORDER BY e.timestamp DESC")
58+
List<PolicyViolationEvent> findRecentViolations(@Param("since") LocalDateTime since);
59+
}

0 commit comments

Comments
 (0)