1111import java .util .stream .Stream ;
1212import com .fasterxml .jackson .databind .ObjectMapper ;
1313import io .sentrius .agent .analysis .agents .agents .ChatAgent ;
14+ import io .sentrius .agent .analysis .agents .verbs .AgentVerbs ;
15+ import io .sentrius .agent .analysis .agents .verbs .TerminalVerbs ;
1416import io .sentrius .agent .analysis .api .UserCommunicationService ;
17+ import io .sentrius .sso .core .exceptions .ZtatException ;
1518import io .sentrius .sso .core .services .agents .ZeroTrustClientService ;
19+ import io .sentrius .sso .genai .Message ;
1620import io .sentrius .sso .protobuf .Session ;
1721import lombok .RequiredArgsConstructor ;
1822import lombok .extern .slf4j .Slf4j ;
@@ -29,15 +33,20 @@ public class ChatWSHandler extends TextWebSocketHandler {
2933
3034 final UserCommunicationService userCommunicationService ;
3135 final ZeroTrustClientService zeroTrustClientService ;
36+ final TerminalVerbs terminalVerbs ;
37+ final AgentVerbs agentVerbs ;
3238 // Store active sessions, using session ID or a custom identifier
3339
3440
3541 private final ChatAgent chatAgent ;
3642
3743 @ Autowired
38- public ChatWSHandler (UserCommunicationService userCommunicationService , ZeroTrustClientService zeroTrustClientService , ChatAgent chatAgent ) {
44+ public ChatWSHandler (UserCommunicationService userCommunicationService , ZeroTrustClientService zeroTrustClientService ,
45+ TerminalVerbs terminalVerbs , AgentVerbs agentVerbs , ChatAgent chatAgent ) {
3946 this .userCommunicationService = userCommunicationService ;
4047 this .zeroTrustClientService = zeroTrustClientService ;
48+ this .terminalVerbs = terminalVerbs ;
49+ this .agentVerbs = agentVerbs ;
4150 this .chatAgent = chatAgent ;
4251 }
4352
@@ -85,6 +94,8 @@ public void afterConnectionEstablished(WebSocketSession session) throws Exceptio
8594 session .sendMessage (new TextMessage (
8695 base64Message
8796 ));
97+
98+ userCommunicationService .createSession (queryParams .get ("sessionId" ), session );
8899 }
89100
90101
@@ -100,7 +111,10 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
100111 Map <String , String > queryParams = parseQueryParams (uri .getQuery ());
101112 String sessionId = queryParams .get ("sessionId" );
102113
103- if (sessionId != null ) {
114+ var websocky = userCommunicationService .getSession (sessionId );
115+
116+ if (sessionId != null && websocky .isPresent ()) {
117+ var websocketCommunication = websocky .get ();
104118 log .info ("Received message from session ID: " + sessionId );
105119 // Handle the message (e.g., process or respond)
106120
@@ -113,7 +127,6 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
113127 Session .ChatMessage .parseFrom (messageBytes );
114128
115129 if (auditLog .getMessage ().equals ("heartbeat" )) {
116- log .info ("heartbeat" );
117130 return ;
118131 }
119132 var json = new ObjectMapper ().readTree (auditLog .getMessage ());
@@ -136,8 +149,25 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
136149 session .close ();
137150 }
138151 return ;
139- } else if ("heartbeat" .equals (auditLog .getMessage ())) {
152+ } else if ("user-message" .equals (json .get ("type" ).asText ())) {
153+ Message userMessage = Message .builder ().role ("user" ).content (json .get ("message" ).asText ()).build ();
140154 log .info ("Received heartbeat from session {}" , sessionId );
155+ var response = agentVerbs .interpretUserData (chatAgent .getAgentExecution (),
156+ websocketCommunication , userMessage );
157+ log .info ("Response: {}" , response );
158+ var newMessage = Session .ChatMessage .newBuilder ()
159+ .setMessage (String .format ("{\" type\" :\" user-message\" ,\" message\" :\" %s\" }" ,
160+ response .getResponseForUser ()))
161+ .setSender ("agent" )
162+ .setChatGroupId ("" )
163+ .setSessionId (Long .parseLong (websocketCommunication .getSessionId ()))
164+ .setTimestamp (System .currentTimeMillis ())
165+ .build ();
166+ messageBytes = newMessage .toByteArray ();
167+ String base64Message = Base64 .getEncoder ().encodeToString (messageBytes );
168+ session .sendMessage (new TextMessage (
169+ base64Message
170+ ));
141171 return ; // Ignore heartbeat messages
142172 } else {
143173 log .info ("Processing message: {}" , auditLog .getMessage ());
@@ -154,7 +184,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
154184 log .info ("Session ID not found in query parameters for message handling." );
155185 }
156186 }
157- }catch (Exception e ){
187+ }catch (Exception | ZtatException e ){
158188 throw new RuntimeException (e );
159189 }
160190 }
0 commit comments