Skip to content

Commit 784706a

Browse files
committed
begin chat dialog
1 parent 1083877 commit 784706a

File tree

14 files changed

+2497
-376
lines changed

14 files changed

+2497
-376
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package io.sentrius.sso.controllers.api;
2+
3+
import java.security.GeneralSecurityException;
4+
import java.time.ZoneOffset;
5+
import java.util.List;
6+
import java.util.stream.Collectors;
7+
import io.sentrius.sso.core.utils.AccessUtil;
8+
import io.sentrius.sso.protobuf.Session.ChatMessage;
9+
import io.sentrius.sso.core.config.SystemOptions;
10+
import io.sentrius.sso.core.controllers.BaseController;
11+
import io.sentrius.sso.core.model.security.enums.SSHAccessEnum;
12+
import io.sentrius.sso.core.model.sessions.SessionLog;
13+
import io.sentrius.sso.core.repository.ChatLogRepository;
14+
import io.sentrius.sso.core.security.service.CryptoService;
15+
import io.sentrius.sso.core.services.ErrorOutputService;
16+
import io.sentrius.sso.core.services.UserService;
17+
import io.sentrius.sso.core.services.auditing.AuditService;
18+
import io.sentrius.sso.core.services.terminal.SessionTrackingService;
19+
import jakarta.servlet.http.HttpServletRequest;
20+
import jakarta.servlet.http.HttpServletResponse;
21+
import lombok.extern.slf4j.Slf4j;
22+
import org.springframework.http.ResponseEntity;
23+
import org.springframework.web.bind.annotation.GetMapping;
24+
import org.springframework.web.bind.annotation.RequestMapping;
25+
import org.springframework.web.bind.annotation.RequestParam;
26+
import org.springframework.web.bind.annotation.RestController;
27+
28+
@Slf4j
29+
@RestController
30+
@RequestMapping("/api/v1/chat")
31+
public class ChatApiController extends BaseController {
32+
private final AuditService auditService;
33+
final CryptoService cryptoService;
34+
final SessionTrackingService sessionTrackingService;
35+
final ChatLogRepository chatLogRepository;
36+
37+
public ChatApiController(
38+
UserService userService,
39+
SystemOptions systemOptions,
40+
ErrorOutputService errorOutputService,
41+
AuditService auditService,
42+
CryptoService cryptoService, SessionTrackingService sessionTrackingService, ChatLogRepository chatLogRepository
43+
) {
44+
super(userService, systemOptions, errorOutputService);
45+
this.auditService = auditService;
46+
this.cryptoService = cryptoService;
47+
this.sessionTrackingService = sessionTrackingService;
48+
this.chatLogRepository = chatLogRepository;
49+
}
50+
51+
public SessionLog createSession(@RequestParam String username, @RequestParam String ipAddress) {
52+
return auditService.createSession(username, ipAddress);
53+
}
54+
55+
@GetMapping("/history")
56+
public ResponseEntity<List<ChatMessage>> getChatHistory(
57+
HttpServletRequest request,
58+
HttpServletResponse response,
59+
@RequestParam(name="sessionId") String sessionIdEncrypted,
60+
@RequestParam(name="chatGroupId") String chatGroupIdEncrypted)
61+
throws GeneralSecurityException {
62+
63+
Long sessionId = Long.parseLong(cryptoService.decrypt(sessionIdEncrypted));
64+
65+
// Check if the user has access to this session
66+
var myConnectedSystem = sessionTrackingService.getConnectedSession(sessionId);
67+
68+
var user = getOperatingUser(request, response);
69+
70+
if (myConnectedSystem == null ||
71+
(
72+
!myConnectedSystem.getUser().getId().equals(user.getId()) &&
73+
!AccessUtil.canAccess(user, SSHAccessEnum.CAN_MANAGE_SYSTEMS))) {
74+
return ResponseEntity.status(403).body(null); // Forbidden access
75+
}
76+
77+
78+
String chatGroupId = cryptoService.decrypt(chatGroupIdEncrypted);
79+
List<ChatMessage> messages = chatLogRepository.findBySessionIdAndChatGroupId(sessionId, chatGroupId)
80+
.stream()
81+
.map(chatLog -> ChatMessage.newBuilder()
82+
.setSessionId(sessionId)
83+
.setChatGroupId(chatGroupId)
84+
.setSender(chatLog.getSender())
85+
.setMessage(chatLog.getMessage())
86+
.setTimestamp(chatLog.getMessageTimestamp().toEpochSecond(ZoneOffset.UTC)).build())
87+
.collect(Collectors.toList());
88+
89+
return ResponseEntity.ok(messages);
90+
}
91+
92+
93+
}
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
2+
package io.sentrius.sso.websocket;
3+
4+
import java.io.IOException;
5+
import java.net.URI;
6+
import java.security.GeneralSecurityException;
7+
import java.sql.Timestamp;
8+
import java.util.Base64;
9+
import java.util.Map;
10+
import java.util.concurrent.ConcurrentHashMap;
11+
import java.util.stream.Collectors;
12+
import java.util.stream.Stream;
13+
import io.sentrius.sso.automation.auditing.Trigger;
14+
import io.sentrius.sso.automation.auditing.TriggerAction;
15+
import io.sentrius.sso.core.security.service.CryptoService;
16+
import io.sentrius.sso.core.services.metadata.TerminalSessionMetadataService;
17+
import io.sentrius.sso.core.services.terminal.SessionTrackingService;
18+
import io.sentrius.sso.protobuf.Session;
19+
import lombok.RequiredArgsConstructor;
20+
import lombok.extern.slf4j.Slf4j;
21+
import org.springframework.stereotype.Component;
22+
import org.springframework.web.socket.TextMessage;
23+
import org.springframework.web.socket.WebSocketSession;
24+
import org.springframework.web.socket.handler.TextWebSocketHandler;
25+
26+
@Slf4j
27+
@Component
28+
@RequiredArgsConstructor
29+
public class ChatWSHandler extends TextWebSocketHandler {
30+
31+
32+
final SessionTrackingService sessionTrackingService;
33+
final SshListenerService sshListenerService;
34+
final CryptoService cryptoService;
35+
final TerminalSessionMetadataService terminalSessionMetadataService;
36+
37+
38+
// Store active sessions, using session ID or a custom identifier
39+
private final ConcurrentHashMap<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
40+
41+
@Override
42+
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
43+
// Extract query parameters from the URI
44+
URI uri = session.getUri();
45+
if (uri != null) {
46+
Map<String, String> queryParams = parseQueryParams(uri.getQuery());
47+
String sessionId = queryParams.get("sessionId");
48+
String chatGropuId = queryParams.get("chatGroupId");
49+
50+
51+
52+
if (sessionId != null) {
53+
// Store the WebSocket session using the session ID from the query parameter
54+
sessions.put(sessionId, session);
55+
log.info("New connection established, session ID: " + sessionId);
56+
sshListenerService.startListeningToSshServer(sessionId, session);
57+
} else {
58+
log.trace("Session ID not found in query parameters.");
59+
session.close(); // Close the session if no valid session ID is provided
60+
}
61+
} else {
62+
log.trace("No URI available for this session.");
63+
session.close(); // Close the session if URI is unavailable
64+
}
65+
}
66+
67+
@Override
68+
protected void handleTextMessage(WebSocketSession session, TextMessage message)
69+
throws IOException, GeneralSecurityException {
70+
71+
// Extract query parameters from the URI again if needed
72+
URI uri = session.getUri();
73+
log.trace("got message {}", uri);
74+
try {
75+
if (uri != null) {
76+
Map<String, String> queryParams = parseQueryParams(uri.getQuery());
77+
String sessionId = queryParams.get("sessionId");
78+
79+
if (sessionId != null) {
80+
log.trace("Received message from session ID: " + sessionId);
81+
// Handle the message (e.g., process or respond)
82+
83+
84+
// Deserialize the protobuf message
85+
byte[] messageBytes = Base64.getDecoder().decode(message.getPayload());
86+
Session.TerminalMessage auditLog =
87+
Session.TerminalMessage.parseFrom(messageBytes);
88+
// Decrypt the session ID
89+
// var sessionIdStr = cryptoService.decrypt(sessionId);
90+
// var sessionIdLong = Long.parseLong(sessionIdStr);
91+
var lookupId = sessionId + "==";
92+
// Retrieve ConnectedSystem from your persistent map using the session ID
93+
var sys = sessionTrackingService.getEncryptedConnectedSession(lookupId);
94+
if (null != sys ) {
95+
boolean allNoAction = true;
96+
log.debug("**** Processing message for session ID: {} with {} actions", sessionId,
97+
sys.getSessionStartupActions().size());
98+
for (var action : sys.getSessionStartupActions()) {
99+
var trigger = action.onMessage(auditLog);
100+
if (trigger.get().getAction() == TriggerAction.JIT_ACTION) {
101+
allNoAction = false;
102+
// drop the message
103+
sys.getTerminalAuditor().setSessionTrigger(trigger.get());
104+
log.debug("**** Setting JIT Trigger: {}", trigger.get());
105+
sessionTrackingService.addSystemTrigger(sys, trigger.get());
106+
return;
107+
} else if (trigger.get().getAction() == TriggerAction.WARN_ACTION) {
108+
allNoAction = false;
109+
// send the message
110+
log.debug("**** Setting WARN Trigger: {}", trigger.get());
111+
sys.getTerminalAuditor().setSessionTrigger(trigger.get());
112+
sessionTrackingService.addSystemTrigger(sys, trigger.get());
113+
} else if (trigger.get().getAction() == TriggerAction.PROMPT_ACTION) {
114+
sessionTrackingService.addTrigger(sys, trigger.get());
115+
return;
116+
}
117+
}
118+
if (allNoAction && sys.getSessionStartupActions().size() > 0) {
119+
log.info("**** Setting NO_ACTION Trigger");
120+
var noActionTrigger = new Trigger(TriggerAction.NO_ACTION, "");
121+
sessionTrackingService.addSystemTrigger(sys, noActionTrigger);
122+
sys.getTerminalAuditor().setSessionTrigger(noActionTrigger);
123+
}
124+
125+
// Get the user's session and handle trigger if present
126+
sshListenerService.processTerminalMessage(sys, auditLog);
127+
}
128+
} else {
129+
log.trace("Session ID not found in query parameters for message handling.");
130+
}
131+
}
132+
}catch (Exception e ){
133+
e.printStackTrace();
134+
throw new RuntimeException(e);
135+
}
136+
}
137+
138+
@Override
139+
public void afterConnectionClosed(WebSocketSession session, org.springframework.web.socket.CloseStatus status) throws Exception {
140+
URI uri = session.getUri();
141+
if (uri != null) {
142+
Map<String, String> queryParams = parseQueryParams(uri.getQuery());
143+
String sessionId = queryParams.get("sessionId");
144+
145+
if (sessionId != null) {
146+
// Remove the session when connection is closed
147+
var lookupId = sessionId + "==";
148+
var sys = sessionTrackingService.getEncryptedConnectedSession(lookupId);
149+
if (null != sys){
150+
log.info("**** Closing session for {}", sys.getSession());
151+
terminalSessionMetadataService.getSessionBySessionLog(sys.getSession()).ifPresent(sessionMetadata -> {
152+
sessionMetadata.setEndTime(new Timestamp(System.currentTimeMillis()));
153+
sessionMetadata.setSessionStatus("CLOSED");
154+
terminalSessionMetadataService.saveSession(sessionMetadata);
155+
});
156+
}
157+
158+
sessions.remove(sessionId);
159+
sshListenerService.removeSession(sessionId);
160+
161+
log.info("Connection closed, session ID: " + sessionId);
162+
}
163+
}
164+
}
165+
166+
// Utility method to parse query parameters
167+
private Map<String, String> parseQueryParams(String query) {
168+
if (query == null || query.isEmpty()) {
169+
return Map.of();
170+
}
171+
return Stream.of(query.split("&"))
172+
.map(param -> param.split("="))
173+
.collect(Collectors.toMap(
174+
param -> param[0],
175+
param -> param.length > 1 ? param[1] : ""
176+
));
177+
}
178+
179+
// Utility method to send a message to a specific session
180+
public void sendMessageToSession(String sessionId, String message) {
181+
WebSocketSession session = sessions.get(sessionId);
182+
if (session != null && session.isOpen()) {
183+
try {
184+
session.sendMessage(new TextMessage(message));
185+
} catch (IOException e) {
186+
System.err.println("Error sending message to session " + sessionId);
187+
e.printStackTrace();
188+
}
189+
} else {
190+
System.err.println("Session not found or already closed: " + sessionId);
191+
}
192+
}
193+
}

api/src/main/java/io/sentrius/sso/websocket/WebSocketConfig.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ public class WebSocketConfig implements WebSocketConfigurer {
1313

1414
private final TerminalWSHandler customWebSocketHandler;
1515
private final AuditSocketHandler auditSocketHandler;
16+
private final ChatWSHandler chatWSHandler;
1617
@Override
1718
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
1819
registry.addHandler(customWebSocketHandler, "/api/v1/ssh/terminal/subscribe")
@@ -21,6 +22,9 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
2122
registry.addHandler(auditSocketHandler, "/api/v1/audit/attach/subscribe")
2223
.setAllowedOriginPatterns("*")
2324
.withSockJS(); // SockJS fallback if needed
25+
registry.addHandler(chatWSHandler, "/api/v1/chat/attach/subscribe")
26+
.setAllowedOriginPatterns("*")
27+
.withSockJS(); // SockJS fallback if needed
2428

2529
}
2630
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
CREATE TABLE IF NOT EXISTS chat_log (
2+
id BIGSERIAL PRIMARY KEY,
3+
session_id BIGINT NOT NULL,
4+
chat_group_id VARCHAR NOT NULL, -- Unique identifier for different chat dialogs within the session
5+
instance_id INTEGER,
6+
sender VARCHAR NOT NULL, -- username or system (e.g., AI agent)
7+
message TEXT NOT NULL,
8+
message_tm TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
9+
FOREIGN KEY (session_id) REFERENCES session_log(id) ON DELETE CASCADE
10+
);

0 commit comments

Comments
 (0)