Skip to content

Commit d37fd28

Browse files
authored
Migrate progress logger functions to use cmdio directly (#3818)
## Changes This continues the progress logger simplification by migrating from the `Logger.Log()` event-based system to direct cmdio calls. This builds on #3811 and #3812, which made the progress logger effectively only write strings to stderr. ## Why This simplifies the codebase by removing the (unused) event abstraction layer while maintaining functionality. The compatibility layer provides a clean migration path and will eventually be removed once the functionality in those functions have a dedicated new home. ## Tests Tests pass. Analysis of the diff suggests that before/after is functionally equivalent. A minor note is that we no longer have a mutex around `.Log()` calls, but there are no concurrent calls to the function.
1 parent 6bdc33e commit d37fd28

File tree

15 files changed

+354
-234
lines changed

15 files changed

+354
-234
lines changed

bundle/run/job.go

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ func (r *jobRunner) logFailedTasks(ctx context.Context, runId int64) {
7474
log.Errorf(ctx, "task %s failed. Unable to fetch error trace: %s", red(task.TaskKey), err)
7575
continue
7676
}
77-
if progressLogger, ok := cmdio.FromContext(ctx); ok {
78-
progressLogger.Log(progress.NewTaskErrorEvent(task.TaskKey, taskInfo.Error, taskInfo.ErrorTrace))
79-
}
77+
cmdio.Log(ctx, progress.NewTaskErrorEvent(task.TaskKey, taskInfo.Error, taskInfo.ErrorTrace))
8078
log.Errorf(ctx, "Task %s failed!\nError:\n%s\nTrace:\n%s",
8179
red(task.TaskKey), taskInfo.Error, taskInfo.ErrorTrace)
8280
} else {
@@ -89,9 +87,8 @@ func (r *jobRunner) logFailedTasks(ctx context.Context, runId int64) {
8987
// jobRunMonitor tracks state for a single job run and provides callbacks
9088
// for monitoring progress.
9189
type jobRunMonitor struct {
92-
ctx context.Context
93-
prevState *jobs.RunState
94-
progressLogger *cmdio.Logger
90+
ctx context.Context
91+
prevState *jobs.RunState
9592
}
9693

9794
// onProgress is the single callback that handles all state tracking and logging.
@@ -104,7 +101,7 @@ func (m *jobRunMonitor) onProgress(info *jobs.Run) {
104101
// First time we see this run.
105102
if m.prevState == nil {
106103
log.Infof(m.ctx, "Run available at %s", info.RunPageUrl)
107-
m.progressLogger.Log(progress.NewJobRunUrlEvent(info.RunPageUrl))
104+
cmdio.Log(m.ctx, progress.NewJobRunUrlEvent(info.RunPageUrl))
108105
}
109106

110107
// No state change: do not log.
@@ -125,7 +122,7 @@ func (m *jobRunMonitor) onProgress(info *jobs.Run) {
125122
RunName: info.RunName,
126123
State: *info.State,
127124
}
128-
m.progressLogger.Log(event)
125+
cmdio.Log(m.ctx, event)
129126
log.Info(m.ctx, event.String())
130127
}
131128

@@ -151,15 +148,8 @@ func (r *jobRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, e
151148

152149
w := r.bundle.WorkspaceClient()
153150

154-
// callback to log progress events. Called on every poll request
155-
progressLogger, ok := cmdio.FromContext(ctx)
156-
if !ok {
157-
return nil, errors.New("no progress logger found")
158-
}
159-
160151
monitor := &jobRunMonitor{
161-
ctx: ctx,
162-
progressLogger: progressLogger,
152+
ctx: ctx,
163153
}
164154

165155
waiter, err := w.Jobs.RunNow(ctx, *req)
@@ -171,7 +161,7 @@ func (r *jobRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, e
171161
details, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{
172162
RunId: waiter.RunId,
173163
})
174-
progressLogger.Log(progress.NewJobRunUrlEvent(details.RunPageUrl))
164+
cmdio.Log(ctx, progress.NewJobRunUrlEvent(details.RunPageUrl))
175165
return nil, err
176166
}
177167

bundle/run/job_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"github.com/databricks/cli/bundle/config"
1010
"github.com/databricks/cli/bundle/config/resources"
1111
"github.com/databricks/cli/libs/cmdio"
12-
"github.com/databricks/cli/libs/flags"
1312
"github.com/databricks/databricks-sdk-go/experimental/mocks"
1413
"github.com/databricks/databricks-sdk-go/service/jobs"
1514
"github.com/stretchr/testify/mock"
@@ -160,7 +159,6 @@ func TestJobRunnerRestart(t *testing.T) {
160159
b.SetWorkpaceClient(m.WorkspaceClient)
161160

162161
ctx := cmdio.MockDiscard(context.Background())
163-
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))
164162

165163
jobApi := m.GetMockJobsAPI()
166164
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
@@ -231,7 +229,6 @@ func TestJobRunnerRestartForContinuousUnpausedJobs(t *testing.T) {
231229
b.SetWorkpaceClient(m.WorkspaceClient)
232230

233231
ctx := cmdio.MockDiscard(context.Background())
234-
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))
235232

236233
jobApi := m.GetMockJobsAPI()
237234

bundle/run/pipeline.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,9 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp
106106

107107
// setup progress logger and tracker to query events
108108
updateTracker := progress.NewUpdateTracker(pipelineID, updateID, w)
109-
progressLogger, ok := cmdio.FromContext(ctx)
110-
if !ok {
111-
return nil, errors.New("no progress logger found")
112-
}
113109

114110
// Log the pipeline update URL as soon as it is available.
115-
progressLogger.Log(progress.NewPipelineUpdateUrlEvent(w.Config.Host, updateID, pipelineID))
111+
cmdio.Log(ctx, progress.NewPipelineUpdateUrlEvent(w.Config.Host, updateID, pipelineID))
116112

117113
if opts.NoWait {
118114
return &output.PipelineOutput{
@@ -129,7 +125,7 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp
129125
return nil, err
130126
}
131127
for _, event := range events {
132-
progressLogger.Log(&event)
128+
cmdio.Log(ctx, &event)
133129
log.Info(ctx, event.String())
134130
}
135131

bundle/run/pipeline_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"github.com/databricks/cli/bundle/config"
1010
"github.com/databricks/cli/bundle/config/resources"
1111
"github.com/databricks/cli/libs/cmdio"
12-
"github.com/databricks/cli/libs/flags"
1312
sdk_config "github.com/databricks/databricks-sdk-go/config"
1413
"github.com/databricks/databricks-sdk-go/experimental/mocks"
1514
"github.com/databricks/databricks-sdk-go/service/pipelines"
@@ -76,7 +75,6 @@ func TestPipelineRunnerRestart(t *testing.T) {
7675
b.SetWorkpaceClient(m.WorkspaceClient)
7776

7877
ctx := cmdio.MockDiscard(context.Background())
79-
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))
8078

8179
mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{
8280
Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) {

bundle/statemgmt/state_push_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/databricks/cli/bundle"
1010
"github.com/databricks/cli/bundle/config"
1111
mockfiler "github.com/databricks/cli/internal/mocks/libs/filer"
12+
"github.com/databricks/cli/libs/cmdio"
1213
"github.com/databricks/cli/libs/filer"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/mock"
@@ -51,7 +52,7 @@ func TestStatePush(t *testing.T) {
5152
identityFiler(mock),
5253
}
5354

54-
ctx := context.Background()
55+
ctx := cmdio.MockDiscard(context.Background())
5556
b := statePushTestBundle(t)
5657

5758
// Write a stale local state file.

cmd/root/root.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313

1414
"github.com/databricks/cli/internal/build"
1515
"github.com/databricks/cli/libs/cmdctx"
16-
"github.com/databricks/cli/libs/cmdio"
1716
"github.com/databricks/cli/libs/dbr"
1817
"github.com/databricks/cli/libs/log"
1918
"github.com/databricks/cli/libs/telemetry"
@@ -140,9 +139,7 @@ Stack Trace:
140139
// Run the command
141140
cmd, err = cmd.ExecuteContextC(ctx)
142141
if err != nil && !errors.Is(err, ErrAlreadyPrinted) {
143-
// If cmdio logger initialization succeeds, then this function logs with the
144-
// initialized cmdio logger, otherwise with the default cmdio logger
145-
cmdio.LogError(cmd.Context(), err)
142+
fmt.Fprintf(cmd.ErrOrStderr(), "Error: %s\n", err.Error())
146143
}
147144

148145
// Log exit status and error

experimental/ssh/internal/proxy/client_server_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ import (
1414
"testing"
1515
"time"
1616

17+
"github.com/databricks/cli/libs/cmdio"
1718
"github.com/gorilla/websocket"
1819
"github.com/stretchr/testify/assert"
1920
"github.com/stretchr/testify/require"
2021
)
2122

2223
func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration) *httptest.Server {
23-
ctx := t.Context()
24+
ctx := cmdio.MockDiscard(t.Context())
2425
connections := NewConnectionsManager(maxClients, shutdownDelay)
2526
proxyServer := NewProxyServer(ctx, connections, func(ctx context.Context) *exec.Cmd {
2627
// 'cat' command reads each line from stdin and sends it to stdout, so we can test end-to-end proxying.
@@ -30,7 +31,7 @@ func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration)
3031
}
3132

3233
func createTestClient(t *testing.T, serverURL string, requestHandoverTick func() <-chan time.Time, errChan chan error) (io.WriteCloser, *testBuffer) {
33-
ctx := t.Context()
34+
ctx := cmdio.MockDiscard(t.Context())
3435
clientInput, clientInputWriter := io.Pipe()
3536
clientOutput := newTestBuffer(t)
3637
wsURL := "ws" + serverURL[4:]

experimental/ssh/internal/setup/setup_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ import (
99
"testing"
1010
"time"
1111

12+
"github.com/databricks/cli/libs/cmdio"
1213
"github.com/databricks/databricks-sdk-go/experimental/mocks"
1314
"github.com/databricks/databricks-sdk-go/service/compute"
1415
"github.com/stretchr/testify/assert"
1516
"github.com/stretchr/testify/require"
1617
)
1718

1819
func TestValidateClusterAccess_SingleUser(t *testing.T) {
19-
ctx := context.Background()
20+
ctx := cmdio.MockDiscard(context.Background())
2021
m := mocks.NewMockWorkspaceClient(t)
2122
clustersAPI := m.GetMockClustersAPI()
2223

@@ -29,7 +30,7 @@ func TestValidateClusterAccess_SingleUser(t *testing.T) {
2930
}
3031

3132
func TestValidateClusterAccess_InvalidAccessMode(t *testing.T) {
32-
ctx := context.Background()
33+
ctx := cmdio.MockDiscard(context.Background())
3334
m := mocks.NewMockWorkspaceClient(t)
3435
clustersAPI := m.GetMockClustersAPI()
3536

@@ -43,7 +44,7 @@ func TestValidateClusterAccess_InvalidAccessMode(t *testing.T) {
4344
}
4445

4546
func TestValidateClusterAccess_ClusterNotFound(t *testing.T) {
46-
ctx := context.Background()
47+
ctx := cmdio.MockDiscard(context.Background())
4748
m := mocks.NewMockWorkspaceClient(t)
4849
clustersAPI := m.GetMockClustersAPI()
4950

@@ -315,7 +316,7 @@ func TestUpdateSSHConfigFile_HandlesReadError(t *testing.T) {
315316
}
316317

317318
func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) {
318-
ctx := context.Background()
319+
ctx := cmdio.MockDiscard(context.Background())
319320
tmpDir := t.TempDir()
320321
configPath := filepath.Join(tmpDir, "ssh_config")
321322

@@ -349,7 +350,7 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) {
349350
}
350351

351352
func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) {
352-
ctx := context.Background()
353+
ctx := cmdio.MockDiscard(context.Background())
353354
tmpDir := t.TempDir()
354355
configPath := filepath.Join(tmpDir, "ssh_config")
355356

@@ -393,7 +394,7 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) {
393394
}
394395

395396
func TestSetup_DoesNotOverrideExistingHost(t *testing.T) {
396-
ctx := context.Background()
397+
ctx := cmdio.MockDiscard(context.Background())
397398
tmpDir := t.TempDir()
398399
configPath := filepath.Join(tmpDir, "ssh_config")
399400

libs/cmdio/compat.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package cmdio
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"strings"
8+
9+
"github.com/manifoldco/promptui"
10+
)
11+
12+
/*
13+
Temporary compatibility layer for the progress logger interfaces.
14+
*/
15+
16+
// Log is a compatibility layer for the progress logger interfaces.
17+
// It writes the string representation of the stringer to the error writer.
18+
func Log(ctx context.Context, str fmt.Stringer) {
19+
LogString(ctx, str.String())
20+
}
21+
22+
// LogString is a compatibility layer for the progress logger interfaces.
23+
// It writes the string to the error writer.
24+
func LogString(ctx context.Context, str string) {
25+
c := fromContext(ctx)
26+
_, _ = io.WriteString(c.err, str)
27+
_, _ = io.WriteString(c.err, "\n")
28+
}
29+
30+
// readLine reads a line from the reader and returns it without the trailing newline characters.
31+
// It is unbuffered because cmdio's stdin is also unbuffered.
32+
// If we were to add a [bufio.Reader] to the mix, we would need to update the other uses of the reader.
33+
// Once cmdio's stdio is made to be buffered, this function can be removed.
34+
func readLine(r io.Reader) (string, error) {
35+
var b strings.Builder
36+
buf := make([]byte, 1)
37+
for {
38+
n, err := r.Read(buf)
39+
if n > 0 {
40+
if buf[0] == '\n' {
41+
break
42+
}
43+
if buf[0] != '\r' {
44+
b.WriteByte(buf[0])
45+
}
46+
}
47+
if err != nil {
48+
if b.Len() == 0 {
49+
return "", err
50+
}
51+
break
52+
}
53+
}
54+
return b.String(), nil
55+
}
56+
57+
// Ask is a compatibility layer for the progress logger interfaces.
58+
// It prompts the user with a question and returns the answer.
59+
func Ask(ctx context.Context, question, defaultVal string) (string, error) {
60+
c := fromContext(ctx)
61+
62+
// Add default value to question prompt.
63+
if defaultVal != "" {
64+
question += fmt.Sprintf(` [%s]`, defaultVal)
65+
}
66+
question += `: `
67+
68+
// Print prompt.
69+
_, err := io.WriteString(c.err, question)
70+
if err != nil {
71+
return "", err
72+
}
73+
74+
// Read user input. Trim new line characters.
75+
ans, err := readLine(c.in)
76+
if err != nil {
77+
return "", err
78+
}
79+
80+
// Return default value if user just presses enter.
81+
if ans == "" {
82+
return defaultVal, nil
83+
}
84+
85+
return ans, nil
86+
}
87+
88+
// AskYesOrNo is a compatibility layer for the progress logger interfaces.
89+
// It prompts the user with a question and returns the answer.
90+
func AskYesOrNo(ctx context.Context, question string) (bool, error) {
91+
ans, err := Ask(ctx, question+" [y/n]", "")
92+
if err != nil {
93+
return false, err
94+
}
95+
return ans == "y", nil
96+
}
97+
98+
func splitAtLastNewLine(s string) (string, string) {
99+
// Split at the newline character
100+
if i := strings.LastIndex(s, "\n"); i != -1 {
101+
return s[:i+1], s[i+1:]
102+
}
103+
// Return the original string if no newline found
104+
return "", s
105+
}
106+
107+
// AskSelect is a compatibility layer for the progress logger interfaces.
108+
// It prompts the user with a question and returns the answer.
109+
func AskSelect(ctx context.Context, question string, choices []string) (string, error) {
110+
c := fromContext(ctx)
111+
112+
// Promptui does not support multiline prompts. So we split the question.
113+
first, last := splitAtLastNewLine(question)
114+
_, err := io.WriteString(c.err, first)
115+
if err != nil {
116+
return "", err
117+
}
118+
119+
// Note: by default this prompt uses os.Stdin and os.Stdout.
120+
// This is contrary to the rest of the original progress logger
121+
// functions that write to stderr.
122+
prompt := promptui.Select{
123+
Label: last,
124+
Items: choices,
125+
HideHelp: true,
126+
Templates: &promptui.SelectTemplates{
127+
Label: "{{.}}: ",
128+
Selected: last + ": {{.}}",
129+
},
130+
}
131+
132+
_, ans, err := prompt.Run()
133+
if err != nil {
134+
return "", err
135+
}
136+
return ans, nil
137+
}

0 commit comments

Comments
 (0)