Skip to content

Commit 1508d65

Browse files
authored
Extract functionality to detect if the CLI is running on DBR (#1889)
## Changes Whether or not the CLI is running on DBR can be detected once and stored in the command's context. By storing it in the context, it can easily be mocked for testing. This builds on the simpler approach and conversation in #1744. It unblocks testing of the DBR-specific paths while not compromising on the checks we can perform to test if the CLI is running on DBR. ## Tests * Unit tests for the new `dbr` package * New unit test for the `ConfigureWSFS` mutator
1 parent 2edfb6c commit 1508d65

File tree

8 files changed

+352
-4
lines changed

8 files changed

+352
-4
lines changed

bundle/config/mutator/configure_wsfs.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@ import (
55
"strings"
66

77
"github.com/databricks/cli/bundle"
8+
"github.com/databricks/cli/libs/dbr"
89
"github.com/databricks/cli/libs/diag"
9-
"github.com/databricks/cli/libs/env"
1010
"github.com/databricks/cli/libs/filer"
1111
"github.com/databricks/cli/libs/vfs"
1212
)
1313

14-
const envDatabricksRuntimeVersion = "DATABRICKS_RUNTIME_VERSION"
15-
1614
type configureWSFS struct{}
1715

1816
func ConfigureWSFS() bundle.Mutator {
@@ -32,7 +30,7 @@ func (m *configureWSFS) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagno
3230
}
3331

3432
// The executable must be running on DBR.
35-
if _, ok := env.Lookup(ctx, envDatabricksRuntimeVersion); !ok {
33+
if !dbr.RunsOnRuntime(ctx) {
3634
return nil
3735
}
3836

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package mutator_test
2+
3+
import (
4+
"context"
5+
"runtime"
6+
"testing"
7+
8+
"github.com/databricks/cli/bundle"
9+
"github.com/databricks/cli/bundle/config/mutator"
10+
"github.com/databricks/cli/libs/dbr"
11+
"github.com/databricks/cli/libs/vfs"
12+
"github.com/databricks/databricks-sdk-go/config"
13+
"github.com/databricks/databricks-sdk-go/experimental/mocks"
14+
"github.com/stretchr/testify/assert"
15+
)
16+
17+
func mockBundleForConfigureWSFS(t *testing.T, syncRootPath string) *bundle.Bundle {
18+
// The native path of the sync root on Windows will never match the /Workspace prefix,
19+
// so the test case for nominal behavior will always fail.
20+
if runtime.GOOS == "windows" {
21+
t.Skip("this test is not applicable on Windows")
22+
}
23+
24+
b := &bundle.Bundle{
25+
SyncRoot: vfs.MustNew(syncRootPath),
26+
}
27+
28+
w := mocks.NewMockWorkspaceClient(t)
29+
w.WorkspaceClient.Config = &config.Config{}
30+
b.SetWorkpaceClient(w.WorkspaceClient)
31+
32+
return b
33+
}
34+
35+
func TestConfigureWSFS_SkipsIfNotWorkspacePrefix(t *testing.T) {
36+
b := mockBundleForConfigureWSFS(t, "/foo")
37+
originalSyncRoot := b.SyncRoot
38+
39+
ctx := context.Background()
40+
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
41+
assert.Empty(t, diags)
42+
assert.Equal(t, originalSyncRoot, b.SyncRoot)
43+
}
44+
45+
func TestConfigureWSFS_SkipsIfNotRunningOnRuntime(t *testing.T) {
46+
b := mockBundleForConfigureWSFS(t, "/Workspace/foo")
47+
originalSyncRoot := b.SyncRoot
48+
49+
ctx := context.Background()
50+
ctx = dbr.MockRuntime(ctx, false)
51+
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
52+
assert.Empty(t, diags)
53+
assert.Equal(t, originalSyncRoot, b.SyncRoot)
54+
}
55+
56+
func TestConfigureWSFS_SwapSyncRoot(t *testing.T) {
57+
b := mockBundleForConfigureWSFS(t, "/Workspace/foo")
58+
originalSyncRoot := b.SyncRoot
59+
60+
ctx := context.Background()
61+
ctx = dbr.MockRuntime(ctx, true)
62+
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
63+
assert.Empty(t, diags)
64+
assert.NotEqual(t, originalSyncRoot, b.SyncRoot)
65+
}

cmd/root/root.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/databricks/cli/internal/build"
1313
"github.com/databricks/cli/libs/cmdio"
14+
"github.com/databricks/cli/libs/dbr"
1415
"github.com/databricks/cli/libs/log"
1516
"github.com/spf13/cobra"
1617
)
@@ -73,6 +74,9 @@ func New(ctx context.Context) *cobra.Command {
7374
// get the context back
7475
ctx = cmd.Context()
7576

77+
// Detect if the CLI is running on DBR and store this on the context.
78+
ctx = dbr.DetectRuntime(ctx)
79+
7680
// Configure our user agent with the command that's about to be executed.
7781
ctx = withCommandInUserAgent(ctx, cmd)
7882
ctx = withCommandExecIdInUserAgent(ctx)

libs/dbr/context.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package dbr
2+
3+
import "context"
4+
5+
// key is a package-local type to use for context keys.
6+
//
7+
// Using an unexported type for context keys prevents key collisions across
8+
// packages since external packages cannot create values of this type.
9+
type key int
10+
11+
const (
12+
// dbrKey is the context key for the detection result.
13+
// The value of 1 is arbitrary and can be any number.
14+
// Other keys in the same package must have different values.
15+
dbrKey = key(1)
16+
)
17+
18+
// DetectRuntime detects whether or not the current
19+
// process is running inside a Databricks Runtime environment.
20+
// It return a new context with the detection result set.
21+
func DetectRuntime(ctx context.Context) context.Context {
22+
if v := ctx.Value(dbrKey); v != nil {
23+
panic("dbr.DetectRuntime called twice on the same context")
24+
}
25+
return context.WithValue(ctx, dbrKey, detect(ctx))
26+
}
27+
28+
// MockRuntime is a helper function to mock the detection result.
29+
// It returns a new context with the detection result set.
30+
func MockRuntime(ctx context.Context, b bool) context.Context {
31+
if v := ctx.Value(dbrKey); v != nil {
32+
panic("dbr.MockRuntime called twice on the same context")
33+
}
34+
return context.WithValue(ctx, dbrKey, b)
35+
}
36+
37+
// RunsOnRuntime returns the detection result from the context.
38+
// It expects a context returned by [DetectRuntime] or [MockRuntime].
39+
//
40+
// We store this value in a context to avoid having to use either
41+
// a global variable, passing a boolean around everywhere, or
42+
// performing the same detection multiple times.
43+
func RunsOnRuntime(ctx context.Context) bool {
44+
v := ctx.Value(dbrKey)
45+
if v == nil {
46+
panic("dbr.RunsOnRuntime called without calling dbr.DetectRuntime first")
47+
}
48+
return v.(bool)
49+
}

libs/dbr/context_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package dbr
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestContext_DetectRuntimePanics(t *testing.T) {
11+
ctx := context.Background()
12+
13+
// Run detection.
14+
ctx = DetectRuntime(ctx)
15+
16+
// Expect a panic if the detection is run twice.
17+
assert.Panics(t, func() {
18+
ctx = DetectRuntime(ctx)
19+
})
20+
}
21+
22+
func TestContext_MockRuntimePanics(t *testing.T) {
23+
ctx := context.Background()
24+
25+
// Run detection.
26+
ctx = MockRuntime(ctx, true)
27+
28+
// Expect a panic if the mock function is run twice.
29+
assert.Panics(t, func() {
30+
MockRuntime(ctx, true)
31+
})
32+
}
33+
34+
func TestContext_RunsOnRuntimePanics(t *testing.T) {
35+
ctx := context.Background()
36+
37+
// Expect a panic if the detection is not run.
38+
assert.Panics(t, func() {
39+
RunsOnRuntime(ctx)
40+
})
41+
}
42+
43+
func TestContext_RunsOnRuntime(t *testing.T) {
44+
ctx := context.Background()
45+
46+
// Run detection.
47+
ctx = DetectRuntime(ctx)
48+
49+
// Expect no panic because detection has run.
50+
assert.NotPanics(t, func() {
51+
RunsOnRuntime(ctx)
52+
})
53+
}
54+
55+
func TestContext_RunsOnRuntimeWithMock(t *testing.T) {
56+
ctx := context.Background()
57+
assert.True(t, RunsOnRuntime(MockRuntime(ctx, true)))
58+
assert.False(t, RunsOnRuntime(MockRuntime(ctx, false)))
59+
}

libs/dbr/detect.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package dbr
2+
3+
import (
4+
"context"
5+
"os"
6+
"runtime"
7+
8+
"github.com/databricks/cli/libs/env"
9+
)
10+
11+
// Dereference [os.Stat] to allow mocking in tests.
12+
var statFunc = os.Stat
13+
14+
// detect returns true if the current process is running on a Databricks Runtime.
15+
// Its return value is meant to be cached in the context.
16+
func detect(ctx context.Context) bool {
17+
// Databricks Runtime implies Linux.
18+
// Return early on other operating systems.
19+
if runtime.GOOS != "linux" {
20+
return false
21+
}
22+
23+
// Databricks Runtime always has the DATABRICKS_RUNTIME_VERSION environment variable set.
24+
if value, ok := env.Lookup(ctx, "DATABRICKS_RUNTIME_VERSION"); !ok || value == "" {
25+
return false
26+
}
27+
28+
// Expect to see a "/databricks" directory.
29+
if fi, err := statFunc("/databricks"); err != nil || !fi.IsDir() {
30+
return false
31+
}
32+
33+
// All checks passed.
34+
return true
35+
}

libs/dbr/detect_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package dbr
2+
3+
import (
4+
"context"
5+
"io/fs"
6+
"runtime"
7+
"testing"
8+
9+
"github.com/databricks/cli/libs/env"
10+
"github.com/databricks/cli/libs/fakefs"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
func requireLinux(t *testing.T) {
15+
if runtime.GOOS != "linux" {
16+
t.Skipf("skipping test on %s", runtime.GOOS)
17+
}
18+
}
19+
20+
func configureStatFunc(t *testing.T, fi fs.FileInfo, err error) {
21+
originalFunc := statFunc
22+
statFunc = func(name string) (fs.FileInfo, error) {
23+
assert.Equal(t, "/databricks", name)
24+
return fi, err
25+
}
26+
27+
t.Cleanup(func() {
28+
statFunc = originalFunc
29+
})
30+
}
31+
32+
func TestDetect_NotLinux(t *testing.T) {
33+
if runtime.GOOS == "linux" {
34+
t.Skip("skipping test on Linux OS")
35+
}
36+
37+
ctx := context.Background()
38+
assert.False(t, detect(ctx))
39+
}
40+
41+
func TestDetect_Env(t *testing.T) {
42+
requireLinux(t)
43+
44+
// Configure other checks to pass.
45+
configureStatFunc(t, fakefs.FileInfo{FakeDir: true}, nil)
46+
47+
t.Run("empty", func(t *testing.T) {
48+
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "")
49+
assert.False(t, detect(ctx))
50+
})
51+
52+
t.Run("non-empty cluster", func(t *testing.T) {
53+
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "15.4")
54+
assert.True(t, detect(ctx))
55+
})
56+
57+
t.Run("non-empty serverless", func(t *testing.T) {
58+
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "client.1.13")
59+
assert.True(t, detect(ctx))
60+
})
61+
}
62+
63+
func TestDetect_Stat(t *testing.T) {
64+
requireLinux(t)
65+
66+
// Configure other checks to pass.
67+
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "non-empty")
68+
69+
t.Run("error", func(t *testing.T) {
70+
configureStatFunc(t, nil, fs.ErrNotExist)
71+
assert.False(t, detect(ctx))
72+
})
73+
74+
t.Run("not a directory", func(t *testing.T) {
75+
configureStatFunc(t, fakefs.FileInfo{}, nil)
76+
assert.False(t, detect(ctx))
77+
})
78+
79+
t.Run("directory", func(t *testing.T) {
80+
configureStatFunc(t, fakefs.FileInfo{FakeDir: true}, nil)
81+
assert.True(t, detect(ctx))
82+
})
83+
}

libs/fakefs/fakefs.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package fakefs
2+
3+
import (
4+
"io/fs"
5+
"time"
6+
)
7+
8+
// DirEntry is a fake implementation of [fs.DirEntry].
9+
type DirEntry struct {
10+
FileInfo
11+
}
12+
13+
func (entry DirEntry) Type() fs.FileMode {
14+
typ := fs.ModePerm
15+
if entry.FakeDir {
16+
typ |= fs.ModeDir
17+
}
18+
return typ
19+
}
20+
21+
func (entry DirEntry) Info() (fs.FileInfo, error) {
22+
return entry.FileInfo, nil
23+
}
24+
25+
// FileInfo is a fake implementation of [fs.FileInfo].
26+
type FileInfo struct {
27+
FakeName string
28+
FakeSize int64
29+
FakeDir bool
30+
FakeMode fs.FileMode
31+
}
32+
33+
func (info FileInfo) Name() string {
34+
return info.FakeName
35+
}
36+
37+
func (info FileInfo) Size() int64 {
38+
return info.FakeSize
39+
}
40+
41+
func (info FileInfo) Mode() fs.FileMode {
42+
return info.FakeMode
43+
}
44+
45+
func (info FileInfo) ModTime() time.Time {
46+
return time.Now()
47+
}
48+
49+
func (info FileInfo) IsDir() bool {
50+
return info.FakeDir
51+
}
52+
53+
func (info FileInfo) Sys() any {
54+
return nil
55+
}

0 commit comments

Comments
 (0)