Commit a0b55e7

Add "ssh server" command (#3475)
## Changes

1. "ssh" and "ssh setup" commands (PR: #3470)
2. "ssh connect" (PR: #3471)

This PR is based on the above. It adds the "ssh server" command and related logic, and finishes the "ssh" implementation by adding the `ssh` command group to the root `cmd` (it's still hidden from the --help output).

Server overview:

- The server itself is started by the client running the `ssh-server-bootstrap.py` job.
- The client passes inputs to this file as jobs API args, and the Python logic reads the values using dbutils widgets.
- One interesting part is the `setup_subreaper` method, which ensures that no child processes leave the parent process group. This is necessary to prevent background processes from being reparented to PID 1 when their parent SSH server session terminates, after which they lose access to the wsfs and dbfs FUSE mounts (access is based on process groups). All major IDEs usually try to spawn their server processes as background processes (by double forking, for example). A rough sketch of the mechanism follows this description.
- The server creates one sshd config in the home folder that is shared by all future sshd processes.
- It spawns sshd with the `-i` flag, so sshd runs in non-daemon mode and the server communicates with it over stdin and stdout (sshd doesn't listen on any ports).
- A separate sshd process is spawned for each new WebSocket connection (the sshd config is reused); there is a limit of 10 connections (for no particular reason).
- The server starts a shutdown timer when there are no active connections; when it expires, the server terminates itself, which terminates the job, and the cluster idle timeout can start ticking after that.
- The server also places a jupyter-init.py file into `$HOME/.ipython/profile_default/startup`. This file makes Jupyter kernels a bit more similar to webapp notebooks: it provides the spark and dbutils globals and handles the %sql and %pip magic commands.

WIP: unit tests

## Tests

Manual and unit
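The subreaper logic itself is not in the files shown below (it lives in the server-side code that is not part of this excerpt). Purely as an illustration of the underlying Linux mechanism, PR_SET_CHILD_SUBREAPER via prctl, a minimal Go sketch could look like this; the `setupSubreaper` name and the use of golang.org/x/sys/unix are assumptions, not the actual implementation:

```go
//go:build linux

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// setupSubreaper (hypothetical) marks the current process as a child
// subreaper, so that descendants whose parents exit are reparented to
// this process instead of to PID 1. That keeps them inside a process
// group that still has access to the wsfs/dbfs FUSE mounts.
func setupSubreaper() error {
	if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil {
		return fmt.Errorf("failed to set child subreaper: %w", err)
	}
	return nil
}
```

A real subreaper also has to reap the orphaned children it inherits (for example by handling SIGCHLD and calling wait), which this sketch leaves out.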
1 parent edeca76 commit a0b55e7

File tree

10 files changed: +865 -2 lines changed

cmd/cmd.go

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@ import (
    "strings"

    "github.com/databricks/cli/cmd/psql"
+   "github.com/databricks/cli/cmd/ssh"

    "github.com/databricks/cli/cmd/account"
    "github.com/databricks/cli/cmd/api"
@@ -77,6 +78,7 @@ func New(ctx context.Context) *cobra.Command {
    cli.AddCommand(version.New())
    cli.AddCommand(selftest.New())
    cli.AddCommand(pipelines.InstallPipelinesCLI())
+   cli.AddCommand(ssh.New())

    // Add workspace command groups, filtering out empty groups or groups with only hidden commands.
    allGroups := workspace.Groups()

cmd/ssh/server.go

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
package ssh

import (
    "time"

    "github.com/databricks/cli/cmd/root"
    "github.com/databricks/cli/libs/cmdctx"
    "github.com/databricks/cli/libs/ssh"
    "github.com/spf13/cobra"
)

func newServerCommand() *cobra.Command {
    cmd := &cobra.Command{
        Use:   "server",
        Short: "Run SSH tunnel server",
        Long: `Run SSH tunnel server.

This command starts an SSH tunnel server that accepts WebSocket connections
and proxies them to local SSH daemon processes.

` + disclaimer,
        // This is an internal command spawned by the SSH client running the "ssh-server-bootstrap.py" job
        Hidden: true,
    }

    var maxClients int
    var shutdownDelay time.Duration
    var clusterID string
    var version string

    cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
    cmd.MarkFlagRequired("cluster")
    cmd.Flags().IntVar(&maxClients, "max-clients", 10, "Maximum number of SSH clients")
    cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", 10*time.Minute, "Delay before shutting down after no pings from clients")
    cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")

    cmd.PreRunE = root.MustWorkspaceClient
    cmd.RunE = func(cmd *cobra.Command, args []string) error {
        ctx := cmd.Context()
        client := cmdctx.WorkspaceClient(ctx)
        opts := ssh.ServerOptions{
            ClusterID:            clusterID,
            MaxClients:           maxClients,
            ShutdownDelay:        shutdownDelay,
            Version:              version,
            ConfigDir:            ".ssh-tunnel",
            ServerPrivateKeyName: "server-private-key",
            ServerPublicKeyName:  "server-public-key",
            DefaultPort:          7772,
            PortRange:            100,
        }
        err := ssh.RunServer(ctx, client, opts)
        if err != nil && ssh.IsNormalClosure(err) {
            return nil
        }
        return err
    }

    return cmd
}
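ssh.RunServer (libs/ssh) is not included in this excerpt. Based on the commit description — one shared sshd config in the home folder and a fresh sshd in non-daemon mode per WebSocket connection, bridged over stdin/stdout — a rough sketch of that per-connection spawn might look like the following; `spawnSSHD` and its parameters are hypothetical names, not the actual API:

```go
package sshsketch

import (
	"context"
	"io"
	"os"
	"os/exec"
)

// spawnSSHD is a hypothetical sketch of the per-connection sshd spawn
// described in the commit message; it is not the actual libs/ssh code.
// The shared sshd config is written once and reused for every connection;
// each WebSocket connection gets its own sshd process.
func spawnSSHD(ctx context.Context, configPath string, conn io.ReadWriter) error {
	// "-i" runs sshd in non-daemon (inetd) mode: it does not listen on a
	// port; the SSH protocol flows over the child's stdin and stdout.
	cmd := exec.CommandContext(ctx, "/usr/sbin/sshd", "-i", "-f", configPath)
	cmd.Stdin = conn  // bytes arriving from the WebSocket
	cmd.Stdout = conn // bytes going back to the WebSocket
	cmd.Stderr = os.Stderr
	return cmd.Run()
}
```

In the real server this sits behind the --max-clients limit (default 10) and the --shutdown-delay timer that terminates the server once no connections remain active.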

cmd/ssh/ssh.go

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ Common workflows:

    cmd.AddCommand(newSetupCommand())
    cmd.AddCommand(newConnectCommand())
+   cmd.AddCommand(newServerCommand())

    return cmd
}

libs/ssh/jupyter-init.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
from typing import List, Optional
from IPython.core.getipython import get_ipython
from IPython.display import display as ip_display
from dbruntime import UserNamespaceInitializer


def _log_exceptions(func):
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            print(f"Executing {func.__name__}")
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")

    return wrapper


_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
_entry_point = _user_namespace_initializer.get_spark_entry_point()
_globals = _user_namespace_initializer.get_namespace_globals()
for name, value in _globals.items():
    print(f"Registering global: {name} = {value}")
    if name not in globals():
        globals()[name] = value


# 'display' from the runtime uses custom widgets that don't work in Jupyter.
# We use the IPython display instead (in combination with the html formatter for DataFrames).
globals()["display"] = ip_display


@_log_exceptions
def _register_runtime_hooks():
    from dbruntime.monkey_patches import apply_dataframe_display_patch
    from dbruntime.IPythonShellHooks import load_ipython_hooks, IPythonShellHook
    from IPython.core.interactiveshell import ExecutionInfo

    # Setting executing_raw_cell before cell execution is required to make dbutils.library.restartPython() work
    class PreRunHook(IPythonShellHook):
        def pre_run_cell(self, info: ExecutionInfo) -> None:
            get_ipython().executing_raw_cell = info.raw_cell

    load_ipython_hooks(get_ipython(), PreRunHook())
    apply_dataframe_display_patch(ip_display)


def _warn_for_dbr_alternative(magic: str):
    import warnings

    """Warn users about magics that have Databricks alternatives."""
    local_magic_dbr_alternative = {"%%sh": "%sh"}
    if magic in local_magic_dbr_alternative:
        warnings.warn(
            f"\\n{magic} is not supported on Databricks. This notebook might fail when running on a Databricks cluster.\\n"
            f"Consider using %{local_magic_dbr_alternative[magic]} instead."
        )


def _throw_if_not_supported(magic: str):
    """Throw an error for magics that are not supported locally."""
    unsupported_dbr_magics = ["%r", "%scala"]
    if magic in unsupported_dbr_magics:
        raise NotImplementedError(f"{magic} is not supported for local Databricks Notebooks.")


def _get_cell_magic(lines: List[str]) -> Optional[str]:
    """Extract cell magic from the first line if it exists."""
    if len(lines) == 0:
        return None
    if lines[0].strip().startswith("%%"):
        return lines[0].split(" ")[0].strip()
    return None


def _get_line_magic(lines: List[str]) -> Optional[str]:
    """Extract line magic from the first line if it exists."""
    if len(lines) == 0:
        return None
    if lines[0].strip().startswith("%"):
        return lines[0].split(" ")[0].strip().strip("%")
    return None


def _handle_cell_magic(lines: List[str]) -> List[str]:
    """Process cell magic commands."""
    cell_magic = _get_cell_magic(lines)
    if cell_magic is None:
        return lines

    _warn_for_dbr_alternative(cell_magic)
    _throw_if_not_supported(cell_magic)
    return lines


def _handle_line_magic(lines: List[str]) -> List[str]:
    """Process line magic commands and transform them appropriately."""
    lmagic = _get_line_magic(lines)
    if lmagic is None:
        return lines

    _warn_for_dbr_alternative(lmagic)
    _throw_if_not_supported(lmagic)

    if lmagic in ["md", "md-sandbox"]:
        lines[0] = "%%markdown" + lines[0].partition("%" + lmagic)[2]
        return lines

    if lmagic == "sh":
        lines[0] = "%%sh" + lines[0].partition("%" + lmagic)[2]
        return lines

    if lmagic == "sql":
        lines = lines[1:]
        spark_string = "global _sqldf\n" + "_sqldf = spark.sql('''" + "".join(lines).replace("'", "\\'") + "''')\n" + "display(_sqldf)\n"
        return spark_string.splitlines(keepends=True)

    if lmagic == "python":
        return lines[1:]

    return lines


def _strip_hash_magic(lines: List[str]) -> List[str]:
    if len(lines) == 0:
        return lines
    if lines[0].startswith("# MAGIC"):
        return [line.partition("# MAGIC ")[2] for line in lines]
    return lines


def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:
    """Main parser function for Databricks magic commands."""
    if len(lines) == 0:
        return lines

    lines_to_ignore = ("# Databricks notebook source", "# COMMAND ----------", "# DBTITLE")
    lines = [line for line in lines if not line.strip().startswith(lines_to_ignore)]
    lines = "".join(lines).strip().splitlines(keepends=True)
    lines = _strip_hash_magic(lines)

    if _get_cell_magic(lines):
        return _handle_cell_magic(lines)

    if _get_line_magic(lines):
        return _handle_line_magic(lines)

    return lines


@_log_exceptions
def _register_magics():
    """Register the magic command parser with IPython."""
    from dbruntime.DatasetInfo import UserNamespaceDict
    from dbruntime.PipMagicOverrides import PipMagicOverrides

    user_ns = UserNamespaceDict(
        _user_namespace_initializer.get_namespace_globals(),
        _entry_point.getDriverConf(),
        _entry_point,
    )
    ip = get_ipython()
    ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)
    ip.register_magics(PipMagicOverrides(_entry_point, _globals["sc"]._conf, user_ns))


@_log_exceptions
def _register_formatters():
    from pyspark.sql import DataFrame
    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataframe

    def df_html(df: DataFrame) -> str:
        return df.toPandas().to_html()

    ip = get_ipython()
    html_formatter = ip.display_formatter.formatters["text/html"]
    html_formatter.for_type(SparkConnectDataframe, df_html)
    html_formatter.for_type(DataFrame, df_html)


_register_magics()
_register_formatters()
_register_runtime_hooks()
libs/ssh/keys.go

Lines changed: 39 additions & 0 deletions
@@ -13,6 +13,7 @@ import (
    "strings"

    "github.com/databricks/cli/libs/cmdio"
+   "github.com/databricks/databricks-sdk-go"
    "golang.org/x/crypto/ssh"
)

@@ -128,3 +129,41 @@ func checkAndGenerateSSHKeyPair(ctx context.Context, keyPath string) (string, st

    return keyPath, strings.TrimSpace(string(publicKeyBytes)), nil
}
+
+func checkAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, privateKeyName, publicKeyName string) ([]byte, []byte, error) {
+   secretsScopeName, err := createSecretsScope(ctx, client, clusterID)
+   if err != nil {
+       return nil, nil, fmt.Errorf("failed to create secrets scope: %w", err)
+   }
+
+   privateKeyBytes, err := getSecret(ctx, client, secretsScopeName, privateKeyName)
+   if err != nil {
+       cmdio.LogString(ctx, "SSH key pair not found in secrets scope, generating a new one...")
+
+       privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair()
+       if err != nil {
+           return nil, nil, fmt.Errorf("failed to generate SSH key pair: %w", err)
+       }
+
+       err = putSecret(ctx, client, secretsScopeName, privateKeyName, string(privateKeyBytes))
+       if err != nil {
+           return nil, nil, err
+       }
+
+       err = putSecret(ctx, client, secretsScopeName, publicKeyName, string(publicKeyBytes))
+       if err != nil {
+           return nil, nil, err
+       }
+
+       return privateKeyBytes, publicKeyBytes, nil
+   } else {
+       cmdio.LogString(ctx, "Using SSH key pair from secrets scope")
+
+       publicKeyBytes, err := getSecret(ctx, client, secretsScopeName, publicKeyName)
+       if err != nil {
+           return nil, nil, fmt.Errorf("failed to get public key from secrets scope: %w", err)
+       }
+
+       return privateKeyBytes, publicKeyBytes, nil
+   }
+}
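checkAndGenerateSSHKeyPairFromSecrets calls a generateSSHKeyPair helper that is not part of this diff. Purely as a sketch of what such a helper could look like — assuming an ed25519 key and the golang.org/x/crypto/ssh encoding helpers this file already imports; the actual key type and encoding are not shown here:

```go
package ssh

import (
	"crypto/ed25519"
	"crypto/rand"
	"encoding/pem"
	"fmt"

	"golang.org/x/crypto/ssh"
)

// generateSSHKeyPair is a hypothetical sketch: it returns a PEM-encoded
// OpenSSH private key and the matching authorized_keys-format public key.
// The real helper in libs/ssh may differ in key type and encoding.
func generateSSHKeyPair() ([]byte, []byte, error) {
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate ed25519 key: %w", err)
	}

	block, err := ssh.MarshalPrivateKey(priv, "")
	if err != nil {
		return nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
	}
	privateKeyBytes := pem.EncodeToMemory(block)

	sshPub, err := ssh.NewPublicKey(pub)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to convert public key: %w", err)
	}
	publicKeyBytes := ssh.MarshalAuthorizedKey(sshPub)

	return privateKeyBytes, publicKeyBytes, nil
}
```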

libs/ssh/proxy.go

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ func (pc *proxyConnection) Close() error {
    // Keep in mind that pc.sendMessage blocks during handover
    err := pc.sendMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
    if err != nil {
-       if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
+       if IsNormalClosure(err) {
            return nil
        } else {
            return fmt.Errorf("failed to send close message: %w", err)
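The IsNormalClosure helper added in this commit is defined in a file not shown in this excerpt. Judging from the gorilla/websocket check it replaces above, it is presumably a thin wrapper along these lines (a sketch, not the confirmed implementation):

```go
package ssh

import "github.com/gorilla/websocket"

// IsNormalClosure reports whether err represents a clean WebSocket
// shutdown (close code 1000), which callers treat as a non-error.
// Sketch based on the check it replaces in proxy.go; the real helper
// may also handle wrapped errors or additional close codes.
func IsNormalClosure(err error) bool {
	return websocket.IsCloseError(err, websocket.CloseNormalClosure)
}
```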

libs/ssh/secrets.go

Lines changed: 17 additions & 0 deletions
@@ -2,6 +2,7 @@ package ssh

import (
    "context"
+   "encoding/base64"
    "errors"
    "fmt"

@@ -24,6 +25,22 @@ func createSecretsScope(ctx context.Context, client *databricks.WorkspaceClient,
    return secretsScope, nil
}

+func getSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, key string) ([]byte, error) {
+   resp, err := client.Secrets.GetSecret(ctx, workspace.GetSecretRequest{
+       Scope: scope,
+       Key:   key,
+   })
+   if err != nil {
+       return nil, fmt.Errorf("failed to get secret %s from scope %s: %w", key, scope, err)
+   }
+
+   value, err := base64.StdEncoding.DecodeString(resp.Value)
+   if err != nil {
+       return nil, fmt.Errorf("failed to decode secret key from base64: %w", err)
+   }
+   return value, nil
+}
+
 func putSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, key, value string) error {
    err := client.Secrets.PutSecret(ctx, workspace.PutSecret{
        Scope: scope,
