Skip to content

Commit f8f2957

Browse files
Copilotphrocker
andcommitted
Cleanup SSH proxy implementation and add comprehensive tests
Co-authored-by: phrocker <[email protected]>
1 parent 20ee269 commit f8f2957

File tree

8 files changed

+943
-67
lines changed

8 files changed

+943
-67
lines changed

ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/controllers/RefreshController.java

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,46 @@
44
import io.sentrius.sso.core.controllers.BaseController;
55
import io.sentrius.sso.core.services.ErrorOutputService;
66
import io.sentrius.sso.core.services.UserService;
7+
import io.sentrius.sso.sshproxy.service.SshProxyServerService;
8+
import lombok.extern.slf4j.Slf4j;
9+
import org.springframework.http.ResponseEntity;
10+
import org.springframework.web.bind.annotation.PostMapping;
11+
import org.springframework.web.bind.annotation.RequestMapping;
12+
import org.springframework.web.bind.annotation.RestController;
713

14+
/**
15+
* REST controller for SSH proxy management operations.
16+
*/
17+
@Slf4j
18+
@RestController
19+
@RequestMapping("/api/ssh-proxy")
820
public class RefreshController extends BaseController {
9-
protected RefreshController(
10-
UserService userService, SystemOptions systemOptions,
11-
ErrorOutputService errorOutputService
21+
22+
private final SshProxyServerService sshProxyServerService;
23+
24+
public RefreshController(
25+
UserService userService,
26+
SystemOptions systemOptions,
27+
ErrorOutputService errorOutputService,
28+
SshProxyServerService sshProxyServerService
1229
) {
1330
super(userService, systemOptions, errorOutputService);
31+
this.sshProxyServerService = sshProxyServerService;
32+
}
33+
34+
/**
35+
* Refreshes the SSH proxy server host groups configuration.
36+
*/
37+
@PostMapping("/refresh")
38+
public ResponseEntity<String> refreshHostGroups() {
39+
try {
40+
log.info("Refreshing SSH proxy host groups configuration");
41+
sshProxyServerService.refreshHostGroups();
42+
return ResponseEntity.ok("SSH proxy host groups refreshed successfully");
43+
} catch (Exception e) {
44+
log.error("Failed to refresh SSH proxy host groups", e);
45+
return ResponseEntity.internalServerError()
46+
.body("Failed to refresh host groups: " + e.getMessage());
47+
}
1448
}
1549
}

ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShell.java

Lines changed: 8 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ public void start(ChannelSession channel, Environment env) throws IOException {
128128

129129
initializeHostSystemSelection();
130130

131-
var connectedSysteem = connect(user, selectedHostSystem.getHostGroups().get(0), selectedHostSystem.getId());
131+
var connectedSystem = connect(user, selectedHostSystem.getHostGroups().get(0), selectedHostSystem.getId());
132132
sendWelcomeMessage();
133-
startShellLoop(connectedSysteem);
133+
startShellLoop(connectedSystem);
134134
} catch (Exception e) {
135135
log.error("Failed to initialize SSH proxy session", e);
136136
callback.onExit(1, "Failed to initialize session");
@@ -251,78 +251,24 @@ private void startShellLoop(ConnectedSystem connectedSystem) throws GeneralSecur
251251
byte b = buffer[i];
252252
char c = (char) b;
253253

254-
255-
/*
256-
if (c == '\r' || c == '\n') {
257-
// Command completed
258-
String command = commandBuffer.toString().trim();
259-
260-
commandBuffer.setLength(0);
261-
out.write("\r\n".getBytes());
262-
sendPrompt();
263-
264-
265-
266-
auditLog.setKeycode(c);
267-
auditLog.setCommand(command);
268-
getSshListenerService().processTerminalMessage(connectedSystem,
269-
auditLog.build());
270-
auditLog =
271-
Session.TerminalMessage.newBuilder();
272-
} else if (c == 3) { // Ctrl+C
273-
auditLog.setKeycode(c);
274-
auditLog =
275-
Session.TerminalMessage.newBuilder();
276-
getSshListenerService().processTerminalMessage(connectedSystem,
277-
auditLog.build());
278-
terminalResponseService.sendMessage("^C\r\n", out);
279-
commandBuffer.setLength(0);
280-
} else if (c == 127 || c == 8) { // Backspace
281-
if (commandBuffer.length() > 0) {
282-
commandBuffer.setLength(commandBuffer.length() - 1);
283-
// out.write("\b \b".getBytes());
284-
}
285-
auditLog.setKeycode(c);
286-
auditLog =
287-
Session.TerminalMessage.newBuilder();
288-
getSshListenerService().processTerminalMessage(connectedSystem,
289-
auditLog.build());
290-
} else if (c >= 32 && c <= 126) { //+ Printable characters
291-
commandBuffer.append(c);
292-
auditLog.setCommand(String.valueOf(c));
293-
auditLog =
294-
Session.TerminalMessage.newBuilder();
295-
getSshListenerService().processTerminalMessage(connectedSystem,
296-
auditLog.build());
297-
// out.write(b);
298-
}
299-
300-
*/
301-
// Ignore other control characters for now
254+
// Process input character and send audit log
302255
if (c >= 32 && c <= 126) {
256+
// Printable characters
303257
auditLog.setCommand(String.valueOf(c));
304258
auditLog.setType(Session.MessageType.PROMPT_DATA);
305259
auditLog.setKeycode(-1);
306260
getSshListenerService().processTerminalMessage(connectedSystem,
307261
auditLog.build());
308-
auditLog =
309-
Session.TerminalMessage.newBuilder();
310-
// out.write(b);
311-
}else {
262+
auditLog = Session.TerminalMessage.newBuilder();
263+
} else {
264+
// Control characters and special keys
312265
auditLog.setKeycode(c);
313266
auditLog.setType(Session.MessageType.PROMPT_DATA);
314-
315-
316267
getSshListenerService().processTerminalMessage(connectedSystem,
317268
auditLog.build());
318-
auditLog =
319-
Session.TerminalMessage.newBuilder();
320-
// out.write(b);
321-
269+
auditLog = Session.TerminalMessage.newBuilder();
322270
}
323271
}
324-
325-
/// getSshListenerService().processTerminalMessage(connectedSystem, auditLog);
326272
}
327273

328274
} catch (IOException e) {

ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/service/HostSystemSelectionService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ public Optional<HostSystem> getDefaultHostSystem() {
8383
if (!hostSystems.isEmpty()) {
8484
HostSystem defaultHost = hostSystems.get(0);
8585
Hibernate.initialize(defaultHost.getHostGroups());
86-
for(HostGroup gropu : defaultHost.getHostGroups()) {
87-
Hibernate.initialize(gropu.getRules());
86+
for(HostGroup group : defaultHost.getHostGroups()) {
87+
Hibernate.initialize(group.getRules());
8888
}
8989
log.info("Using default HostSystem: {} ({}:{})",
9090
defaultHost.getDisplayName(), defaultHost.getHost(), defaultHost.getPort());
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package io.sentrius.sso.sshproxy.config;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.springframework.boot.test.context.SpringBootTest;
5+
import org.springframework.test.context.TestPropertySource;
6+
7+
import static org.junit.jupiter.api.Assertions.*;
8+
9+
@SpringBootTest(classes = {SshProxyConfig.class})
10+
@TestPropertySource(properties = {
11+
"sentrius.ssh-proxy.enabled=false"
12+
})
13+
class SshProxyConfigTest {
14+
15+
@Test
16+
void testDefaultValues() {
17+
SshProxyConfig config = new SshProxyConfig();
18+
19+
assertEquals(2222, config.getPort());
20+
assertEquals("/tmp/hostkey.ser", config.getHostKeyPath());
21+
assertTrue(config.isEnabled());
22+
assertEquals(100, config.getMaxConcurrentSessions());
23+
24+
assertNotNull(config.getConnection());
25+
assertEquals(30000, config.getConnection().getConnectionTimeout());
26+
assertEquals(60000, config.getConnection().getKeepAliveInterval());
27+
assertEquals(3, config.getConnection().getMaxRetries());
28+
}
29+
30+
@Test
31+
void testSettersAndGetters() {
32+
SshProxyConfig config = new SshProxyConfig();
33+
34+
config.setPort(2223);
35+
config.setHostKeyPath("/custom/path/hostkey.ser");
36+
config.setEnabled(false);
37+
config.setMaxConcurrentSessions(200);
38+
39+
assertEquals(2223, config.getPort());
40+
assertEquals("/custom/path/hostkey.ser", config.getHostKeyPath());
41+
assertFalse(config.isEnabled());
42+
assertEquals(200, config.getMaxConcurrentSessions());
43+
}
44+
45+
@Test
46+
void testConnectionConfiguration() {
47+
SshProxyConfig config = new SshProxyConfig();
48+
SshProxyConfig.Connection connection = config.getConnection();
49+
50+
connection.setConnectionTimeout(45000);
51+
connection.setKeepAliveInterval(90000);
52+
connection.setMaxRetries(5);
53+
54+
assertEquals(45000, connection.getConnectionTimeout());
55+
assertEquals(90000, connection.getKeepAliveInterval());
56+
assertEquals(5, connection.getMaxRetries());
57+
}
58+
59+
@Test
60+
void testConnectionSubclass() {
61+
SshProxyConfig.Connection connection = new SshProxyConfig.Connection();
62+
63+
// Test default values
64+
assertEquals(30000, connection.getConnectionTimeout());
65+
assertEquals(60000, connection.getKeepAliveInterval());
66+
assertEquals(3, connection.getMaxRetries());
67+
68+
// Test setters
69+
connection.setConnectionTimeout(15000);
70+
connection.setKeepAliveInterval(30000);
71+
connection.setMaxRetries(1);
72+
73+
assertEquals(15000, connection.getConnectionTimeout());
74+
assertEquals(30000, connection.getKeepAliveInterval());
75+
assertEquals(1, connection.getMaxRetries());
76+
}
77+
78+
@Test
79+
void testConfigurationEquality() {
80+
SshProxyConfig config1 = new SshProxyConfig();
81+
SshProxyConfig config2 = new SshProxyConfig();
82+
83+
// Initially both should have same default values
84+
assertEquals(config1.getPort(), config2.getPort());
85+
assertEquals(config1.getHostKeyPath(), config2.getHostKeyPath());
86+
assertEquals(config1.isEnabled(), config2.isEnabled());
87+
assertEquals(config1.getMaxConcurrentSessions(), config2.getMaxConcurrentSessions());
88+
89+
// Change one and verify they're different
90+
config1.setPort(3333);
91+
assertNotEquals(config1.getPort(), config2.getPort());
92+
}
93+
94+
@Test
95+
void testConnectionEquality() {
96+
SshProxyConfig.Connection conn1 = new SshProxyConfig.Connection();
97+
SshProxyConfig.Connection conn2 = new SshProxyConfig.Connection();
98+
99+
// Initially both should have same default values
100+
assertEquals(conn1.getConnectionTimeout(), conn2.getConnectionTimeout());
101+
assertEquals(conn1.getKeepAliveInterval(), conn2.getKeepAliveInterval());
102+
assertEquals(conn1.getMaxRetries(), conn2.getMaxRetries());
103+
104+
// Change one and verify they're different
105+
conn1.setConnectionTimeout(99999);
106+
assertNotEquals(conn1.getConnectionTimeout(), conn2.getConnectionTimeout());
107+
}
108+
}

0 commit comments

Comments
 (0)