Skip to content

Commit d95fc1d

Browse files
committed
Add authorized_keys manager
1 parent 2be4d38 commit d95fc1d

File tree

4 files changed

+267
-18
lines changed

4 files changed

+267
-18
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
"strings"
8+
"sync"
9+
10+
"github.com/databricks/cli/experimental/ssh/internal/keys"
11+
"github.com/databricks/cli/libs/log"
12+
"github.com/databricks/databricks-sdk-go"
13+
)
14+
15+
type AuthorizedKeysManager struct {
16+
mu sync.Mutex
17+
filePath string
18+
client *databricks.WorkspaceClient
19+
secretScope string
20+
addedKeys map[string]bool
21+
}
22+
23+
func NewAuthorizedKeysManager(client *databricks.WorkspaceClient, filePath, secretScope string) *AuthorizedKeysManager {
24+
return &AuthorizedKeysManager{
25+
filePath: filePath,
26+
client: client,
27+
secretScope: secretScope,
28+
addedKeys: make(map[string]bool),
29+
}
30+
}
31+
32+
// Adds a public key from secrets scope to the authorized_keys file.
33+
// If the key has already been added, this is a no-op.
34+
func (akm *AuthorizedKeysManager) AddKey(ctx context.Context, publicKeyName string) error {
35+
akm.mu.Lock()
36+
defer akm.mu.Unlock()
37+
38+
if akm.addedKeys[publicKeyName] {
39+
log.Infof(ctx, "Public key %s already added, skipping", publicKeyName)
40+
return nil
41+
}
42+
43+
log.Infof(ctx, "Adding public key from secret: %s", publicKeyName)
44+
clientPublicKey, err := keys.GetSecret(ctx, akm.client, akm.secretScope, publicKeyName)
45+
if err != nil {
46+
return fmt.Errorf("failed to get client public key: %w", err)
47+
}
48+
49+
authKeys, err := os.OpenFile(akm.filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
50+
if err != nil {
51+
return fmt.Errorf("failed to open authorized keys file: %w", err)
52+
}
53+
defer authKeys.Close()
54+
55+
content := strings.TrimSpace(string(clientPublicKey))
56+
_, err = authKeys.WriteString("\n" + content)
57+
if err != nil {
58+
return err
59+
}
60+
61+
akm.addedKeys[publicKeyName] = true
62+
return nil
63+
}
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"errors"
7+
"fmt"
8+
"os"
9+
"path/filepath"
10+
"strings"
11+
"sync"
12+
"testing"
13+
14+
"github.com/databricks/databricks-sdk-go/experimental/mocks"
15+
mockWorkspace "github.com/databricks/databricks-sdk-go/experimental/mocks/service/workspace"
16+
"github.com/databricks/databricks-sdk-go/service/workspace"
17+
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
19+
)
20+
21+
type testSetup struct {
22+
ctx context.Context
23+
authKeysPath string
24+
mockClient *mocks.MockWorkspaceClient
25+
secretsAPI *mockWorkspace.MockSecretsInterface
26+
manager *AuthorizedKeysManager
27+
}
28+
29+
func setupTest(t *testing.T) *testSetup {
30+
ctx := context.Background()
31+
tempDir := t.TempDir()
32+
authKeysPath := filepath.Join(tempDir, "authorized_keys")
33+
34+
m := mocks.NewMockWorkspaceClient(t)
35+
secretsAPI := m.GetMockSecretsAPI()
36+
manager := NewAuthorizedKeysManager(m.WorkspaceClient, authKeysPath, "test-scope")
37+
38+
return &testSetup{
39+
ctx: ctx,
40+
authKeysPath: authKeysPath,
41+
mockClient: m,
42+
secretsAPI: secretsAPI,
43+
manager: manager,
44+
}
45+
}
46+
47+
func (s *testSetup) mockGetSecret(keyName, publicKey string) {
48+
encodedKey := base64.StdEncoding.EncodeToString([]byte(publicKey))
49+
s.secretsAPI.EXPECT().GetSecret(s.ctx, workspace.GetSecretRequest{
50+
Scope: "test-scope",
51+
Key: keyName,
52+
}).Return(&workspace.GetSecretResponse{
53+
Value: encodedKey,
54+
}, nil)
55+
}
56+
57+
func (s *testSetup) mockGetSecretOnce(keyName, publicKey string) {
58+
encodedKey := base64.StdEncoding.EncodeToString([]byte(publicKey))
59+
s.secretsAPI.EXPECT().GetSecret(s.ctx, workspace.GetSecretRequest{
60+
Scope: "test-scope",
61+
Key: keyName,
62+
}).Return(&workspace.GetSecretResponse{
63+
Value: encodedKey,
64+
}, nil).Once()
65+
}
66+
67+
func (s *testSetup) mockGetSecretError(keyName string, err error) {
68+
s.secretsAPI.EXPECT().GetSecret(s.ctx, workspace.GetSecretRequest{
69+
Scope: "test-scope",
70+
Key: keyName,
71+
}).Return(nil, err)
72+
}
73+
74+
func TestAuthorizedKeysManager_AddKey_Success(t *testing.T) {
75+
s := setupTest(t)
76+
publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC... [email protected]"
77+
78+
s.mockGetSecret("test-key", publicKey)
79+
80+
err := s.manager.AddKey(s.ctx, "test-key")
81+
require.NoError(t, err)
82+
83+
content, err := os.ReadFile(s.authKeysPath)
84+
require.NoError(t, err)
85+
assert.Contains(t, string(content), publicKey)
86+
assert.True(t, s.manager.addedKeys["test-key"])
87+
}
88+
89+
func TestAuthorizedKeysManager_AddKey_Deduplication(t *testing.T) {
90+
s := setupTest(t)
91+
publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC... [email protected]"
92+
93+
s.mockGetSecretOnce("test-key", publicKey)
94+
95+
err := s.manager.AddKey(s.ctx, "test-key")
96+
require.NoError(t, err)
97+
98+
contentAfterFirst, err := os.ReadFile(s.authKeysPath)
99+
require.NoError(t, err)
100+
101+
err = s.manager.AddKey(s.ctx, "test-key")
102+
require.NoError(t, err)
103+
104+
contentAfterSecond, err := os.ReadFile(s.authKeysPath)
105+
require.NoError(t, err)
106+
107+
assert.Equal(t, string(contentAfterFirst), string(contentAfterSecond))
108+
occurrences := strings.Count(string(contentAfterSecond), publicKey)
109+
assert.Equal(t, 1, occurrences)
110+
}
111+
112+
func TestAuthorizedKeysManager_AddKey_GetSecretError(t *testing.T) {
113+
s := setupTest(t)
114+
115+
s.mockGetSecretError("missing-key", errors.New("secret not found"))
116+
117+
err := s.manager.AddKey(s.ctx, "missing-key")
118+
assert.Error(t, err)
119+
assert.Contains(t, err.Error(), "failed to get client public key")
120+
assert.False(t, s.manager.addedKeys["missing-key"])
121+
}
122+
123+
func TestAuthorizedKeysManager_AddKey_FileWriteError(t *testing.T) {
124+
s := setupTest(t)
125+
publicKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC... [email protected]"
126+
127+
s.mockGetSecret("test-key", publicKey)
128+
s.manager = NewAuthorizedKeysManager(s.mockClient.WorkspaceClient, "/nonexistent/directory/authorized_keys", "test-scope")
129+
130+
err := s.manager.AddKey(s.ctx, "test-key")
131+
assert.Error(t, err)
132+
assert.Contains(t, err.Error(), "failed to open authorized keys file")
133+
assert.False(t, s.manager.addedKeys["test-key"])
134+
}
135+
136+
func TestAuthorizedKeysManager_AddKey_ThreadSafety(t *testing.T) {
137+
s := setupTest(t)
138+
const numGoroutines = 10
139+
const numKeysPerGoroutine = 10
140+
141+
expectedKeys := make([]string, 0, numGoroutines*numKeysPerGoroutine)
142+
for i := range numGoroutines {
143+
for j := range numKeysPerGoroutine {
144+
keyName := fmt.Sprintf("key-%d-%d", i, j)
145+
publicKey := fmt.Sprintf("ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC%d%d [email protected]", i, j)
146+
s.mockGetSecret(keyName, publicKey)
147+
expectedKeys = append(expectedKeys, publicKey)
148+
}
149+
}
150+
151+
var wg sync.WaitGroup
152+
for i := range numGoroutines {
153+
wg.Go(func() {
154+
for j := range numKeysPerGoroutine {
155+
keyName := fmt.Sprintf("key-%d-%d", i, j)
156+
err := s.manager.AddKey(s.ctx, keyName)
157+
assert.NoError(t, err)
158+
}
159+
})
160+
}
161+
162+
wg.Wait()
163+
164+
assert.Equal(t, numGoroutines*numKeysPerGoroutine, len(s.manager.addedKeys))
165+
166+
content, err := os.ReadFile(s.authKeysPath)
167+
require.NoError(t, err)
168+
169+
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
170+
for _, publicKey := range expectedKeys {
171+
assert.Contains(t, lines, publicKey)
172+
}
173+
}
174+
175+
func TestAuthorizedKeysManager_AddKey_MultipleKeys(t *testing.T) {
176+
s := setupTest(t)
177+
keys := map[string]string{
178+
"key1": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC1 [email protected]",
179+
"key2": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC2 [email protected]",
180+
"key3": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC3 [email protected]",
181+
}
182+
183+
for keyName, publicKey := range keys {
184+
s.mockGetSecret(keyName, publicKey)
185+
}
186+
187+
for keyName := range keys {
188+
err := s.manager.AddKey(s.ctx, keyName)
189+
require.NoError(t, err)
190+
}
191+
192+
content, err := os.ReadFile(s.authKeysPath)
193+
require.NoError(t, err)
194+
195+
for _, publicKey := range keys {
196+
assert.Contains(t, string(content), publicKey)
197+
}
198+
199+
assert.Equal(t, len(keys), len(s.manager.addedKeys))
200+
}

experimental/ssh/internal/server/server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,15 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt
6969
return fmt.Errorf("failed to save Jupyter init script: %w", err)
7070
}
7171

72-
connections := proxy.NewConnectionsManager(opts.MaxClients, opts.ShutdownDelay)
72+
authKeysManager := NewAuthorizedKeysManager(client, authKeysPath, opts.SecretScopeName)
7373
createServerCommand := func(ctx context.Context, publicKeyName string) (*exec.Cmd, error) {
74-
err := updateAuthorizedKeys(ctx, client, authKeysPath, opts.SecretScopeName, publicKeyName)
74+
err := authKeysManager.AddKey(ctx, publicKeyName)
7575
if err != nil {
7676
return nil, fmt.Errorf("failed to store auth key: %w", err)
7777
}
7878
return createSSHDProcess(ctx, sshdConfigPath), nil
7979
}
80+
connections := proxy.NewConnectionsManager(opts.MaxClients, opts.ShutdownDelay)
8081
http.Handle("/ssh", proxy.NewProxyServer(ctx, connections, createServerCommand))
8182
http.HandleFunc("/metadata", serveMetadata)
8283
go handleTimeout(ctx, connections.TimedOut, opts.ShutdownDelay)

experimental/ssh/internal/server/sshd.go

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient,
4343

4444
sshdConfig := filepath.Join(sshDir, "sshd_config")
4545
authKeysPath := filepath.Join(sshDir, "authorized_keys")
46+
// Prepare an empty authorized_keys file, it will be updated each time a new client connects
4647
if err := os.WriteFile(authKeysPath, []byte(""), 0o600); err != nil {
4748
return "", "", err
4849
}
@@ -86,22 +87,6 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient,
8687
return sshdConfig, authKeysPath, nil
8788
}
8889

89-
func updateAuthorizedKeys(ctx context.Context, client *databricks.WorkspaceClient, authKeysPath, secretScopeName, publicKeyName string) error {
90-
log.Info(ctx, "Using public key secret name:"+publicKeyName)
91-
clientPublicKey, err := keys.GetSecret(ctx, client, secretScopeName, publicKeyName)
92-
if err != nil {
93-
return fmt.Errorf("failed to get client public key: %w", err)
94-
}
95-
authKeys, err := os.OpenFile(authKeysPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
96-
if err != nil {
97-
return fmt.Errorf("failed to open authorized keys file: %w", err)
98-
}
99-
defer authKeys.Close()
100-
content := strings.TrimSpace(string(clientPublicKey))
101-
_, err = authKeys.WriteString("\n" + content)
102-
return err
103-
}
104-
10590
func createSSHDProcess(ctx context.Context, configPath string) *exec.Cmd {
10691
return exec.CommandContext(ctx, "/usr/sbin/sshd", "-f", configPath, "-i")
10792
}

0 commit comments

Comments
 (0)