Skip to content

Commit f8d2bbb

Browse files
committed
add jwt verification
1 parent 3150e3e commit f8d2bbb

File tree

110 files changed

+792
-2815
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

110 files changed

+792
-2815
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ llm-proxy/target/**
4646
llm-proxy/target/
4747
llm-dataplane/target/**
4848
sentrius-llm-dataplane/target/**
49+
sentrius-llm-core/target/**
4950
llm-dataplane/target/
5051
sentrius-llm-dataplane/target/
52+
sentrius-llm-core/target/
5153
node/*
5254
node_modules/*
5355
api/node_modules/*

ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.sentrius.agent.analysis.agents.agents;
22

3+
import java.util.UUID;
34
import java.util.concurrent.TimeUnit;
45
import com.fasterxml.jackson.core.JsonProcessingException;
56
import com.fasterxml.jackson.databind.node.ArrayNode;
@@ -45,6 +46,8 @@ public class ChatAgent implements ApplicationListener<ApplicationReadyEvent> {
4546
private volatile boolean running = true;
4647
private Thread workerThread;
4748

49+
private AgentExecution agentExecution;
50+
4851
public ArrayNode promptAgent(AgentExecution execution) throws ZtatException {
4952
while(true){
5053
try {
@@ -79,14 +82,15 @@ public void onApplicationEvent(final ApplicationReadyEvent event) {
7982
var keyPair = agentKeyService.getKeyPair();
8083

8184
try {
85+
var agentName = agentConfigOptions.getNamePrefix() + "-" + UUID.randomUUID().toString();
8286
var base64PublicKey = agentKeyService.getBase64PublicKey(keyPair.getPublic());
83-
var agentRegistrationDTO = agentClientService.bootstrap(agentConfigOptions.getName(), base64PublicKey
87+
var agentRegistrationDTO = agentClientService.bootstrap(agentName, base64PublicKey
8488
, keyPair.getPublic().getAlgorithm());
8589

8690
var encryptedSecret = agentRegistrationDTO.getClientSecret();
8791
var decryptedSecret = agentKeyService.
8892
decryptWithPrivateKey(encryptedSecret, keyPair.getPrivate());
89-
keycloakService.createKeycloakClient(agentConfigOptions.getName(),
93+
keycloakService.createKeycloakClient(agentName,
9094
decryptedSecret);
9195

9296

@@ -100,22 +104,23 @@ public void onApplicationEvent(final ApplicationReadyEvent event) {
100104
final UserDTO user = UserDTO.builder()
101105
.username(zeroTrustClientService.getUsername())
102106
.build();
103-
var execution = agentExecutionService.getAgentExecution(user);
107+
agentExecution = agentExecutionService.getAgentExecution(user);
108+
104109
try {
105-
agentClientService.heartbeat(execution, execution.getUser().getUsername());
110+
agentClientService.heartbeat(agentExecution, agentExecution.getUser().getUsername());
106111
} catch (ZtatException e) {
107112
throw new RuntimeException(e);
108113
}
109114

110115
while(running) {
111116

112117
try {
113-
var register = zeroTrustClientService.registerAgent(execution);
118+
var register = zeroTrustClientService.registerAgent(agentExecution);
114119
log.info("Registered agent response: {}", register);
115120

116121
var ztat = JsonUtil.MAPPER.readValue(register, Ztat.class);
117-
execution.setZtatToken(ztat.getZtatToken());
118-
execution.setCommunicationId(ztat.getCommunicationId());
122+
agentExecution.setZtatToken(ztat.getZtatToken());
123+
agentExecution.setCommunicationId(ztat.getCommunicationId());
119124
break;
120125
}catch (Exception | ZtatException e) {
121126

@@ -135,7 +140,7 @@ public void onApplicationEvent(final ApplicationReadyEvent event) {
135140
try {
136141

137142
Thread.sleep(5_000);
138-
agentClientService.heartbeat(execution, execution.getUser().getUsername());
143+
agentClientService.heartbeat(agentExecution, agentExecution.getUser().getUsername());
139144
} catch (InterruptedException | ZtatException ex) {
140145
throw new RuntimeException(ex);
141146
}
@@ -153,4 +158,7 @@ public void shutdown() {
153158
}
154159
}
155160

161+
public AgentExecution getAgentExecution() {
162+
return agentExecution;
163+
}
156164
}

ai-agent/src/main/java/io/sentrius/agent/analysis/api/BotController.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public ResponseEntity<AgentStatus> getStatus() {
5555
@GetMapping("/describe")
5656
public ResponseEntity<AgentRegistrationDTO> describeAgent() {
5757
AgentRegistrationDTO dto = AgentRegistrationDTO.builder()
58-
.agentName(agentConfig.getName())
58+
.agentName(agentConfig.getNamePrefix())
5959
.agentPublicKey(agentKeyService.getKeyPair().getPublic().toString())
6060
.agentPublicKeyAlgo(agentKeyService.getKeyPair().getPublic().getAlgorithm())
6161
.agentType(agentConfig.getType())
@@ -67,7 +67,7 @@ public ResponseEntity<AgentRegistrationDTO> describeAgent() {
6767
public ResponseEntity<AgentRegistrationDTO> getRegistration
6868
() {
6969
AgentRegistrationDTO dto = AgentRegistrationDTO.builder()
70-
.agentName(agentConfig.getName())
70+
.agentName(agentConfig.getNamePrefix())
7171
.agentPublicKey(agentKeyService.getKeyPair().getPublic().toString())
7272
.agentPublicKeyAlgo(agentKeyService.getKeyPair().getPublic().getAlgorithm())
7373
.build();

ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java

Lines changed: 96 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,53 +6,88 @@
66
import java.security.GeneralSecurityException;
77
import java.util.Base64;
88
import java.util.Map;
9+
import java.util.UUID;
910
import java.util.stream.Collectors;
1011
import java.util.stream.Stream;
12+
import com.fasterxml.jackson.databind.ObjectMapper;
13+
import io.sentrius.agent.analysis.agents.agents.ChatAgent;
1114
import io.sentrius.agent.analysis.api.UserCommunicationService;
15+
import io.sentrius.sso.core.services.agents.ZeroTrustClientService;
1216
import io.sentrius.sso.protobuf.Session;
1317
import lombok.RequiredArgsConstructor;
1418
import lombok.extern.slf4j.Slf4j;
19+
import org.springframework.beans.factory.annotation.Autowired;
1520
import org.springframework.stereotype.Component;
21+
import org.springframework.web.socket.CloseStatus;
1622
import org.springframework.web.socket.TextMessage;
1723
import org.springframework.web.socket.WebSocketSession;
1824
import org.springframework.web.socket.handler.TextWebSocketHandler;
1925

2026
@Slf4j
2127
@Component
22-
@RequiredArgsConstructor
2328
public class ChatWSHandler extends TextWebSocketHandler {
2429

25-
UserCommunicationService userCommunicationService;
30+
final UserCommunicationService userCommunicationService;
31+
final ZeroTrustClientService zeroTrustClientService;
2632
// Store active sessions, using session ID or a custom identifier
2733

2834

35+
private final ChatAgent chatAgent;
36+
37+
@Autowired
38+
public ChatWSHandler(UserCommunicationService userCommunicationService, ZeroTrustClientService zeroTrustClientService, ChatAgent chatAgent) {
39+
this.userCommunicationService = userCommunicationService;
40+
this.zeroTrustClientService = zeroTrustClientService;
41+
this.chatAgent = chatAgent;
42+
}
43+
2944
@Override
3045
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
31-
// Extract query parameters from the URI
46+
log.info("New connection established");
3247
URI uri = session.getUri();
33-
if (uri != null) {
34-
Map<String, String> queryParams = parseQueryParams(uri.getQuery());
35-
String sessionId = queryParams.get("sessionId");
36-
37-
48+
if (uri == null) {
49+
session.close(CloseStatus.BAD_DATA);
50+
return;
51+
}
3852

39-
if (sessionId != null) {
40-
// Store the WebSocket session using the session ID from the query parameter
41-
userCommunicationService.createSession(sessionId,session);
42-
log.info("New connection established, session ID: " + sessionId);
43-
// until we have another human on the other side we don't need a thread for this.
44-
//chatListenerService.startChatListener(sessionId, session);
53+
Map<String, String> queryParams = parseQueryParams(uri.getQuery());
54+
Long sessionId = Long.valueOf( queryParams.get("sessionId") );
55+
String chatGroupId = queryParams.get("chatGroupId");
56+
String ztatToken = queryParams.get("ztat");
4557

46-
} else {
47-
log.info("Session ID not found in query parameters.");
48-
session.close(); // Close the session if no valid session ID is provided
49-
}
50-
} else {
51-
log.info("No URI available for this session.");
52-
session.close(); // Close the session if URI is unavailable
58+
if (sessionId == null || ztatToken == null) {
59+
log.warn("Missing sessionId or ZTAT");
60+
session.close(CloseStatus.BAD_DATA);
61+
return;
5362
}
63+
64+
// Store session
65+
userCommunicationService.createSession(queryParams.get("sessionId"), session);
66+
log.info("Session {} created for incoming connection", sessionId);
67+
68+
// Generate and store nonce for this session
69+
String nonce = UUID.randomUUID().toString();
70+
session.getAttributes().put("ztatNonce", nonce);
71+
session.getAttributes().put("ztatToken", ztatToken);
72+
session.getAttributes().put("sessionId", sessionId);
73+
74+
// Send challenge to the client
75+
log.info("Sending challenge to client: {}", nonce);
76+
var challenge = Session.ChatMessage.newBuilder()
77+
.setMessage(String.format("{\"type\":\"challenge\",\"nonce\":\"%s\"}", nonce))
78+
.setSender("agent")
79+
.setChatGroupId(chatGroupId)
80+
.setSessionId(sessionId)
81+
.setTimestamp(System.currentTimeMillis())
82+
.build();
83+
byte[] messageBytes = challenge.toByteArray();
84+
String base64Message = Base64.getEncoder().encodeToString(messageBytes);
85+
session.sendMessage(new TextMessage(
86+
base64Message
87+
));
5488
}
5589

90+
5691
@Override
5792
protected void handleTextMessage(WebSocketSession session, TextMessage message)
5893
throws IOException, GeneralSecurityException {
@@ -70,16 +105,48 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
70105
// Handle the message (e.g., process or respond)
71106

72107

108+
var connection = userCommunicationService.getSession(sessionId);
73109
// Deserialize the protobuf message
74-
byte[] messageBytes = Base64.getDecoder().decode(message.getPayload());
75-
Session.ChatMessage auditLog =
76-
Session.ChatMessage.parseFrom(messageBytes);
77-
if (auditLog.getMessage().equals("heartbeat")){
78-
log.info("heartbeat");
79-
return;
80-
}
81110

82-
var connection = userCommunicationService.getSession(sessionId);
111+
byte[] messageBytes = Base64.getDecoder().decode(message.getPayload());
112+
Session.ChatMessage auditLog =
113+
Session.ChatMessage.parseFrom(messageBytes);
114+
115+
if (auditLog.getMessage().equals("heartbeat")) {
116+
log.info("heartbeat");
117+
return;
118+
}
119+
var json = new ObjectMapper().readTree(auditLog.getMessage());
120+
if ("challenge-response".equals(json.get("type").asText())) {
121+
String signature = json.get("signature").asText();
122+
String publicKey = json.get("publicKey").asText();
123+
String nonce = (String) session.getAttributes().get("ztatNonce");
124+
String ztat = (String) session.getAttributes().get("ztatToken");
125+
126+
boolean verified =
127+
zeroTrustClientService.verifyZtatChallenge(chatAgent.getAgentExecution(), ztat, nonce,
128+
signature,
129+
publicKey);
130+
131+
if (verified) {
132+
session.getAttributes().put("verified", true);
133+
log.info("ZTAT challenge verified for session {}", session.getId());
134+
} else {
135+
log.warn("ZTAT challenge failed for session {}", session.getId());
136+
session.close();
137+
}
138+
return;
139+
} else if ("heartbeat".equals(auditLog.getMessage())) {
140+
log.info("Received heartbeat from session {}", sessionId);
141+
return; // Ignore heartbeat messages
142+
} else {
143+
log.info("Processing message: {}", auditLog.getMessage());
144+
// Process the message as needed
145+
//chatAgent.handleChatMessage(sessionId, auditLog);
146+
}
147+
148+
149+
83150

84151

85152

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package io.sentrius.agent.analysis.api.websocket;
2+
3+
import lombok.RequiredArgsConstructor;
4+
import lombok.extern.slf4j.Slf4j;
5+
import org.springframework.beans.factory.annotation.Autowired;
6+
import org.springframework.http.HttpHeaders;
7+
import org.springframework.http.server.ServerHttpRequest;
8+
import org.springframework.http.server.ServerHttpResponse;
9+
import org.springframework.security.oauth2.jwt.Jwt;
10+
import org.springframework.security.oauth2.jwt.JwtDecoder;
11+
import org.springframework.security.oauth2.jwt.JwtException;
12+
import org.springframework.stereotype.Component;
13+
import org.springframework.web.socket.server.HandshakeInterceptor;
14+
import org.springframework.web.socket.WebSocketHandler;
15+
16+
import java.net.URI;
17+
import java.util.Arrays;
18+
import java.util.Map;
19+
20+
@Slf4j
21+
@Component
22+
public class JwtHandshakeInterceptor implements HandshakeInterceptor {
23+
24+
private final JwtDecoder jwtDecoder;
25+
26+
@Autowired
27+
public JwtHandshakeInterceptor(JwtDecoder jwtDecoder) {
28+
this.jwtDecoder = jwtDecoder;
29+
}
30+
31+
@Override
32+
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
33+
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
34+
log.info("Handshake attempt: {}", request.getURI());
35+
36+
URI uri = request.getURI();
37+
String query = uri.getQuery();
38+
String token = null;
39+
40+
if (query != null && query.contains("ztat=")) {
41+
token = Arrays.stream(query.split("&"))
42+
.filter(s -> s.startsWith("ztat="))
43+
.map(s -> s.substring("ztat=".length()))
44+
.findFirst()
45+
.orElse(null);
46+
}
47+
48+
log.info("Token from query: {}", token);
49+
50+
if (token != null) {
51+
try {
52+
Jwt jwt = jwtDecoder.decode(token);
53+
log.info("JWT decoded: {}", jwt.getClaims());
54+
attributes.put("jwt", jwt);
55+
return true;
56+
} catch (JwtException e) {
57+
log.warn("JWT validation failed: {}", e.getMessage());
58+
return false;
59+
}
60+
}
61+
62+
log.warn("No token found in query string.");
63+
return false;
64+
}
65+
66+
@Override
67+
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
68+
WebSocketHandler wsHandler, Exception exception) {
69+
log.info("After handshake.");
70+
}
71+
}
Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,35 @@
11
package io.sentrius.agent.analysis.api.websocket;
22

33
import lombok.RequiredArgsConstructor;
4+
import lombok.extern.slf4j.Slf4j;
45
import org.springframework.beans.factory.annotation.Value;
56
import org.springframework.context.annotation.Configuration;
67
import org.springframework.web.socket.config.annotation.EnableWebSocket;
78
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
89
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
10+
import org.springframework.beans.factory.annotation.Autowired;
911

1012
@Configuration
1113
@EnableWebSocket
1214
@RequiredArgsConstructor
15+
@Slf4j
1316
public class WebSocketConfig implements WebSocketConfigurer {
1417

1518
@Value("${agent.listen.websocket:false}") // Default is false
1619
private boolean listenWebSocket;
1720

1821
private final ChatWSHandler chatWSHandler;
22+
23+
@Autowired
24+
private JwtHandshakeInterceptor jwtHandshakeInterceptor;
25+
1926
@Override
2027
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
2128
if (listenWebSocket) {
29+
log.info("WebSocket is enabled, registering handlers.");
2230
registry.addHandler(chatWSHandler, "/api/v1/chat/attach/subscribe")
23-
.setAllowedOriginPatterns("*")
24-
.withSockJS(); // SockJS fallback if needed
25-
31+
.setAllowedOriginPatterns("*");
2632
}
2733
}
28-
}
34+
}
35+

0 commit comments

Comments
 (0)