diff --git a/go.mod b/go.mod index 9a0b009a..157f5a17 100644 --- a/go.mod +++ b/go.mod @@ -109,6 +109,7 @@ require ( github.com/skeema/knownhosts v1.1.0 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect + github.com/stretchr/objx v0.4.0 // indirect github.com/subosito/gotenv v1.4.1 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index ae1e3528..4ded0f0a 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -3,7 +3,6 @@ package ls import ( "fmt" - "os" "github.com/brevdev/brev-cli/pkg/analytics" "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" @@ -39,7 +38,7 @@ func NewCmdLs(t *terminal.Terminal, loginLsStore LsStore, noLoginLsStore LsStore cmd := &cobra.Command{ Annotations: map[string]string{"context": ""}, - Use: "ls", + Use: "ls [orgs|organizations|workspaces|users|hosts]", Aliases: []string{"list"}, Short: "List instances within active org", Long: "List instances within your active org. List all instances if no active org is set.", @@ -90,8 +89,14 @@ func NewCmdLs(t *terminal.Terminal, loginLsStore LsStore, noLoginLsStore LsStore return nil }, - Args: cmderrors.TransformToValidationError(cobra.MinimumNArgs(0)), - ValidArgs: []string{"orgs", "workspaces"}, + Args: cmderrors.TransformToValidationError(func(cmd *cobra.Command, args []string) error { + // Allow 0 or 1 argument, and only valid ones + if len(args) > 1 { + return fmt.Errorf("this command accepts only zero or one argument") + } + return cobra.OnlyValidArgs(cmd, args) + }), + ValidArgs: []string{"org", "orgs", "organization", "organizations", "workspace", "workspaces", "user", "users", "host", "hosts"}, RunE: func(cmd *cobra.Command, args []string) error { err := RunLs(t, loginLsStore, args, org, showAll) if err != nil { @@ -162,26 +167,23 @@ func RunLs(t *terminal.Terminal, lsStore LsStore, args []string, orgflag string, return breverrors.WrapAndTrace(err) } - org, err := getOrgForRunLs(lsStore, orgflag) + orgEntity, err := getOrgForRunLs(lsStore, orgflag) if err != nil { return breverrors.WrapAndTrace(err) } - if len(args) > 1 { - return breverrors.NewValidationError("too many args provided") - } - if len(args) == 1 { //nolint:gocritic // don't want to switch - err = handleLsArg(ls, args[0], user, org, showAll) + if len(args) == 0 { + // No argument provided, list workspaces by default + err = ls.RunWorkspaces(orgEntity, user, showAll) if err != nil { return breverrors.WrapAndTrace(err) } - } else if len(args) == 0 { - err = ls.RunWorkspaces(org, user, showAll) + } else { + // Handle the provided argument + err = handleLsArg(ls, args[0], user, orgEntity, showAll) if err != nil { return breverrors.WrapAndTrace(err) } - } else { - return fmt.Errorf("unhandle ls arguments") } return nil @@ -189,31 +191,40 @@ func RunLs(t *terminal.Terminal, lsStore LsStore, args []string, orgflag string, func handleLsArg(ls *Ls, arg string, user *entity.User, org *entity.Organization, showAll bool) error { // todo refactor this to cmd.register - //nolint:gocritic // idk how to write this as a switch - if util.IsSingularOrPlural(arg, "org") || util.IsSingularOrPlural(arg, "organization") { // handle org, orgs, and organization(s) - err := ls.RunOrgs() - if err != nil { + switch { + case util.IsSingularOrPlural(arg, "org") || util.IsSingularOrPlural(arg, "organization"): + // Handle organizations + if err := ls.RunOrgs(); err != nil { return breverrors.WrapAndTrace(err) } - return nil - } else if util.IsSingularOrPlural(arg, "workspace") { - err := ls.RunWorkspaces(org, user, showAll) - if err != nil { + case util.IsSingularOrPlural(arg, "workspace"): + // Handle workspaces + if err := ls.RunWorkspaces(org, user, showAll); err != nil { return breverrors.WrapAndTrace(err) } - } else if util.IsSingularOrPlural(arg, "user") && featureflag.IsAdmin(user.GlobalUserType) { - err := ls.RunUser(showAll) - if err != nil { - return breverrors.WrapAndTrace(err) + case util.IsSingularOrPlural(arg, "user"): + // Handle users, only if the user is an admin + if featureflag.IsAdmin(user.GlobalUserType) { + if err := ls.RunUser(showAll); err != nil { + return breverrors.WrapAndTrace(err) + } + } else { + return breverrors.NewValidationError("user management is only available for admins") } - return nil - } else if util.IsSingularOrPlural(arg, "host") && featureflag.IsAdmin(user.GlobalUserType) { - err := ls.RunHosts(org) - if err != nil { - return breverrors.WrapAndTrace(err) + case util.IsSingularOrPlural(arg, "host"): + // Handle hosts, only if the user is an admin + if featureflag.IsAdmin(user.GlobalUserType) { + if err := ls.RunHosts(org); err != nil { + return breverrors.WrapAndTrace(err) + } + } else { + return breverrors.NewValidationError("host management is only available for admins") } - return nil + default: + // If the argument is not recognized, return a validation error + return breverrors.NewValidationError("unrecognized argument") } + return nil } @@ -412,7 +423,7 @@ func displayProjects(t *terminal.Terminal, orgName string, projects []virtualpro if len(projects) > 0 { fmt.Print("\n") t.Vprintf("%d other projects in Org "+t.Yellow(orgName)+"\n", len(projects)) - displayProjectsTable(projects) + displayProjectsTable(t, projects) fmt.Print("\n") t.Vprintf(t.Green("Join a project:\n") + @@ -438,7 +449,7 @@ func getBrevTableOptions() table.Options { func displayWorkspacesTable(t *terminal.Terminal, workspaces []entity.Workspace, userID string) { ta := table.NewWriter() - ta.SetOutputMirror(os.Stdout) + ta.SetOutputMirror(t.Out()) ta.Style().Options = getBrevTableOptions() header := table.Row{"Name", "Status", "ID", "Machine"} if enableSSHCol { @@ -471,7 +482,7 @@ func getWorkspaceDisplayStatus(w entity.Workspace) string { func displayOrgTable(t *terminal.Terminal, orgs []entity.Organization, currentOrg *entity.Organization) { ta := table.NewWriter() - ta.SetOutputMirror(os.Stdout) + ta.SetOutputMirror(t.Out()) ta.Style().Options = getBrevTableOptions() header := table.Row{"NAME", "ID"} ta.AppendHeader(header) @@ -485,9 +496,9 @@ func displayOrgTable(t *terminal.Terminal, orgs []entity.Organization, currentOr ta.Render() } -func displayProjectsTable(projects []virtualproject.VirtualProject) { +func displayProjectsTable(t *terminal.Terminal, projects []virtualproject.VirtualProject) { ta := table.NewWriter() - ta.SetOutputMirror(os.Stdout) + ta.SetOutputMirror(t.Out()) ta.Style().Options = getBrevTableOptions() header := table.Row{"NAME", "MEMBERS"} ta.AppendHeader(header) diff --git a/pkg/cmd/ls/ls_test.go b/pkg/cmd/ls/ls_test.go index 16ee671a..ed65085d 100644 --- a/pkg/cmd/ls/ls_test.go +++ b/pkg/cmd/ls/ls_test.go @@ -1 +1,280 @@ package ls + +import ( + "testing" + + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockLsStore is a mock implementation of the LsStore interface +type MockLsStore struct { + mock.Mock +} + +func (m *MockLsStore) UpdateUser(userID string, updatedUser *entity.UpdateUser) (*entity.User, error) { + args := m.Called(userID, updatedUser) + return args.Get(0).(*entity.User), args.Error(1) +} + +func (m *MockLsStore) GetCurrentWorkspaceID() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) +} + +func (m *MockLsStore) GetAllWorkspaces(options *store.GetWorkspacesOptions) ([]entity.Workspace, error) { + args := m.Called() + return args.Get(0).([]entity.Workspace), args.Error(1) +} + +func (m *MockLsStore) GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error) { + args := m.Called(organizationID, options) + return args.Get(0).([]entity.Workspace), args.Error(1) +} + +func (m *MockLsStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + args := m.Called() + return args.Get(0).(*entity.Organization), args.Error(1) +} + +func (m *MockLsStore) GetCurrentUser() (*entity.User, error) { + args := m.Called() + return args.Get(0).(*entity.User), args.Error(1) +} + +func (m *MockLsStore) GetUsers(queryParams map[string]string) ([]entity.User, error) { + args := m.Called(queryParams) + return args.Get(0).([]entity.User), args.Error(1) +} + +func (m *MockLsStore) GetWorkspace(workspaceID string) (*entity.Workspace, error) { + args := m.Called(workspaceID) + return args.Get(0).(*entity.Workspace), args.Error(1) +} + +func (m *MockLsStore) GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) { + args := m.Called(options) + return args.Get(0).([]entity.Organization), args.Error(1) +} + +func TestLsCommand_RunLs_NoArgs(t *testing.T) { + mockTerminal, outBuffer, verboseBuffer, errBuffer := terminal.NewTestTerminal() + + mockLsStore := new(MockLsStore) + testOrg := &entity.Organization{ + ID: "test-org-id", + Name: "test-org", + } + testUser := &entity.User{ + ID: "test-user-id", + Name: "Test User", + } + testWorkspaces := []entity.Workspace{ + { + ID: "workspace-1", + Name: "Workspace 1", + CreatedByUserID: "test-user-id", + Status: entity.Running, + }, + } + + mockLsStore.On("GetCurrentUser").Return(testUser, nil) + mockLsStore.On("GetActiveOrganizationOrDefault").Return(testOrg, nil) + mockLsStore.On("GetWorkspaces", "test-org-id", (*store.GetWorkspacesOptions)(nil)).Return(testWorkspaces, nil) + mockLsStore.On("GetOrganizations", (*store.GetOrganizationsOptions)(nil)).Return([]entity.Organization{*testOrg}, nil) + mockLsStore.On("UpdateUser", testUser.ID, mock.AnythingOfType("*entity.UpdateUser")).Return(testUser, nil) + mockLsStore.On("GetCurrentWorkspaceID").Return("workspace-1", nil) + + cmd := NewCmdLs(mockTerminal, mockLsStore, mockLsStore) + + cmd.SetOut(outBuffer) + cmd.SetErr(errBuffer) + // Execute the command without any arguments + cmd.SetArgs([]string{}) + + err := cmd.Execute() + assert.NoError(t, err) + assert.Contains(t, verboseBuffer.String(), "You have 1 instances in Org test-org") + assert.Contains(t, outBuffer.String(), "NAME STATUS ID MACHINE \n Workspace 1 RUNNING workspace-1 (gpu) \n") +} + +func TestLsCommand_RunLs_OrgArg(t *testing.T) { + mockTerminal, outBuffer, verboseBuffer, errBuffer := terminal.NewTestTerminal() + + mockLsStore := new(MockLsStore) + testOrg := &entity.Organization{ + ID: "test-org-id", + Name: "test-org", + } + testUser := &entity.User{ + ID: "test-user-id", + Name: "Test User", + } + + mockLsStore.On("GetCurrentUser").Return(testUser, nil) + mockLsStore.On("GetOrganizations", (*store.GetOrganizationsOptions)(nil)).Return([]entity.Organization{*testOrg}, nil) + mockLsStore.On("UpdateUser", testUser.ID, mock.AnythingOfType("*entity.UpdateUser")).Return(testUser, nil) + mockLsStore.On("GetCurrentWorkspaceID").Return("workspace-1", nil) + mockLsStore.On("GetActiveOrganizationOrDefault").Return(testOrg, nil) + + cmd := NewCmdLs(mockTerminal, mockLsStore, mockLsStore) + cmd.SetOut(outBuffer) + cmd.SetErr(errBuffer) + + // Execute the command with the "org" argument + cmd.SetArgs([]string{"org"}) + + err := cmd.Execute() + assert.NoError(t, err) + assert.Contains(t, verboseBuffer.String(), "Your organizations:") + assert.Contains(t, outBuffer.String(), "* test-org") +} + +func TestLsCommand_RunLs_InvalidArg(t *testing.T) { + mockTerminal, outBuffer, _, errBuffer := terminal.NewTestTerminal() + + mockLsStore := new(MockLsStore) + testOrg := &entity.Organization{ + ID: "test-org-id", + Name: "test-org", + } + testUser := &entity.User{ + ID: "test-user-id", + Name: "Test User", + } + + mockLsStore.On("GetCurrentUser").Return(testUser, nil) + mockLsStore.On("GetOrganizations", (*store.GetOrganizationsOptions)(nil)).Return([]entity.Organization{*testOrg}, nil) + mockLsStore.On("UpdateUser", testUser.ID, mock.AnythingOfType("*entity.UpdateUser")).Return(testUser, nil) + mockLsStore.On("GetCurrentWorkspaceID").Return("workspace-1", nil) + mockLsStore.On("GetActiveOrganizationOrDefault").Return(testOrg, nil) + + cmd := NewCmdLs(mockTerminal, mockLsStore, mockLsStore) + cmd.SetOut(outBuffer) + cmd.SetErr(errBuffer) + + // Execute the command with the "org" argument + cmd.SetArgs([]string{"rubbish"}) + + err := cmd.Execute() + assert.Error(t, err) + assert.Contains(t, errBuffer.String(), `invalid argument`) + assert.Contains(t, outBuffer.String(), "Usage:\n") +} + +func TestLsCommand_RunLs_InvalidOrg(t *testing.T) { + mockTerminal, outBuffer, _, errBuffer := terminal.NewTestTerminal() + + mockLsStore := new(MockLsStore) + testUser := &entity.User{ + ID: "test-user-id", + Name: "Test User", + } + + mockLsStore.On("GetCurrentUser").Return(testUser, nil) + mockLsStore.On("GetOrganizations", &store.GetOrganizationsOptions{Name: "invalid-org"}).Return([]entity.Organization{}, nil) + mockLsStore.On("UpdateUser", testUser.ID, mock.AnythingOfType("*entity.UpdateUser")).Return(testUser, nil) + mockLsStore.On("GetCurrentWorkspaceID").Return("workspace-1", nil) + + cmd := NewCmdLs(mockTerminal, mockLsStore, mockLsStore) + cmd.SetOut(outBuffer) + cmd.SetErr(errBuffer) + + // Execute the command with an invalid --org flag + cmd.SetArgs([]string{"--org", "invalid-org"}) + + err := cmd.Execute() + + assert.Error(t, err) + assert.Contains(t, errBuffer.String(), "no org found with name invalid-org") +} + +func TestLsCommand_RunLs_WithOrgFlag(t *testing.T) { + mockTerminal, outBuffer, verboseBuffer, errBuffer := terminal.NewTestTerminal() + mockLsStore := new(MockLsStore) + testOrg := &entity.Organization{ + ID: "test-org-id", + Name: "test-org", + } + testUser := &entity.User{ + ID: "test-user-id", + Name: "Test User", + } + testWorkspaces := []entity.Workspace{ + { + ID: "workspace-1", + Name: "Workspace 1", + CreatedByUserID: "test-user-id", + Status: entity.Running, + }, + } + + mockLsStore.On("GetCurrentUser").Return(testUser, nil) + mockLsStore.On("GetOrganizations", (*store.GetOrganizationsOptions)(nil)).Return([]entity.Organization{*testOrg}, nil) + mockLsStore.On("GetOrganizations", &store.GetOrganizationsOptions{Name: "test-org"}).Return([]entity.Organization{*testOrg}, nil) + mockLsStore.On("GetWorkspaces", "test-org-id", (*store.GetWorkspacesOptions)(nil)).Return(testWorkspaces, nil) + mockLsStore.On("UpdateUser", testUser.ID, mock.AnythingOfType("*entity.UpdateUser")).Return(testUser, nil) + mockLsStore.On("GetCurrentWorkspaceID").Return("workspace-1", nil) + + cmd := NewCmdLs(mockTerminal, mockLsStore, mockLsStore) + + // Execute the command with the --org flag + cmd.SetOut(outBuffer) + cmd.SetErr(errBuffer) + cmd.SetArgs([]string{"--org", "test-org"}) + + err := cmd.Execute() + assert.NoError(t, err) + assert.Contains(t, verboseBuffer.String(), "You have 1 instances in Org test-org") + assert.Contains(t, outBuffer.String(), "Workspace 1") +} + +func TestLsCommand_RunLs_AllFlag(t *testing.T) { + mockTerminal, outBuffer, verboseBuffer, errBuffer := terminal.NewTestTerminal() + + mockLsStore := new(MockLsStore) + testOrg := &entity.Organization{ + ID: "test-org-id", + Name: "test-org", + } + testUser := &entity.User{ + ID: "test-user-id", + Name: "Test User", + } + allWorkspaces := []entity.Workspace{ + { + ID: "workspace-1", + Name: "Workspace 1", + CreatedByUserID: "test-user-id", + Status: entity.Running, + }, + { + ID: "workspace-2", + Name: "Workspace 2", + CreatedByUserID: "test-user-id-2", + Status: entity.Stopped, + }, + } + + mockLsStore.On("GetCurrentUser").Return(testUser, nil) + mockLsStore.On("GetActiveOrganizationOrDefault").Return(testOrg, nil) + mockLsStore.On("GetWorkspaces", "test-org-id", (*store.GetWorkspacesOptions)(nil)).Return(allWorkspaces, nil) + mockLsStore.On("GetOrganizations", (*store.GetOrganizationsOptions)(nil)).Return([]entity.Organization{*testOrg}, nil) + mockLsStore.On("UpdateUser", testUser.ID, mock.AnythingOfType("*entity.UpdateUser")).Return(testUser, nil) + mockLsStore.On("GetCurrentWorkspaceID").Return("workspace-1", nil) + + cmd := NewCmdLs(mockTerminal, mockLsStore, mockLsStore) + cmd.SetOut(outBuffer) + cmd.SetErr(errBuffer) + // Execute the command with the --all flag + cmd.SetArgs([]string{"--all"}) + + err := cmd.Execute() + assert.NoError(t, err) + assert.Contains(t, verboseBuffer.String(), "You have 1 instances in Org test-org\nno other projects in Org test-org\n%!(EXTRA int=0)Invite a teamate:\n\tbrev invite") + assert.Contains(t, outBuffer.String(), " NAME STATUS ID MACHINE \n Workspace 1 RUNNING workspace-1 (gpu) \n") +} diff --git a/pkg/cmd/shell/shell.go b/pkg/cmd/shell/shell.go index 173ff8c5..b73a00bc 100644 --- a/pkg/cmd/shell/shell.go +++ b/pkg/cmd/shell/shell.go @@ -40,6 +40,10 @@ type ShellStore interface { } func NewCmdShell(t *terminal.Terminal, store ShellStore, noLoginStartStore ShellStore) *cobra.Command { + return newCmdShellWithDeps(t, store, noLoginStartStore, exec.Command) +} + +func newCmdShellWithDeps(t *terminal.Terminal, store ShellStore, noLoginStartStore ShellStore, executor cmdExecutor) *cobra.Command { var runRemoteCMD bool var directory string var host bool @@ -54,6 +58,11 @@ func NewCmdShell(t *terminal.Terminal, store ShellStore, noLoginStartStore Shell Args: cmderrors.TransformToValidationError(cmderrors.TransformToValidationError(cobra.ExactArgs(1))), ValidArgsFunction: completions.GetAllWorkspaceNameCompletionHandler(noLoginStartStore, t), RunE: func(cmd *cobra.Command, args []string) error { + // Check if SSH client is available + if err := checkSSHClient(executor); err != nil { + return breverrors.WrapAndTrace(err) + } + err := runShellCommand(t, store, args[0], directory, host) if err != nil { return breverrors.WrapAndTrace(err) @@ -138,6 +147,17 @@ func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID, return nil } +type cmdExecutor func(name string, arg ...string) *exec.Cmd + +// checkSSHClient checks if the SSH client is available on the system. +func checkSSHClient(executor cmdExecutor) error { + cmd := executor("ssh", "-V") + if err := cmd.Run(); err != nil { + return fmt.Errorf("SSH client is not installed or not available in PATH. Please install SSH client to use this feature") + } + return nil +} + func waitForSSHToBeAvailable(sshAlias string, s *spinner.Spinner) error { counter := 0 s.Suffix = " waiting for SSH connection to be available" @@ -151,10 +171,16 @@ func waitForSSHToBeAvailable(sshAlias string, s *spinner.Spinner) error { } outputStr := string(out) - stdErr := strings.Split(outputStr, "\n")[1] + stdErr := strings.Split(outputStr, "\n") + var stdErrMessage string + if len(stdErr) > 1 { + stdErrMessage = stdErr[1] + } else { + stdErrMessage = outputStr // Fallback if splitting fails + } - if counter == 120 || !store.SatisfactorySSHErrMessage(stdErr) { - return breverrors.WrapAndTrace(errors.New("\n" + stdErr)) + if counter == 120 || !store.SatisfactorySSHErrMessage(stdErrMessage) { + return breverrors.WrapAndTrace(errors.New("\n" + stdErrMessage)) } counter++ @@ -212,8 +238,8 @@ func pollUntil(s *spinner.Spinner, wsid string, state string, shellStore ShellSt isReady := false s.Suffix = waitMsg s.Start() + defer s.Stop() for !isReady { - time.Sleep(5 * time.Second) ws, err := shellStore.GetWorkspace(wsid) if err != nil { return breverrors.WrapAndTrace(err) @@ -222,6 +248,8 @@ func pollUntil(s *spinner.Spinner, wsid string, state string, shellStore ShellSt if ws.Status == state { isReady = true } + + time.Sleep(5 * time.Second) } return nil } diff --git a/pkg/cmd/shell/shell_test.go b/pkg/cmd/shell/shell_test.go index b1f847e4..16825fc7 100644 --- a/pkg/cmd/shell/shell_test.go +++ b/pkg/cmd/shell/shell_test.go @@ -1 +1,226 @@ package shell + +import ( + "errors" + "os/exec" + "testing" + + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockShellStore struct { + mock.Mock + store.FileStore +} + +func (m *MockShellStore) GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) { + args := m.Called(options) + return args.Get(0).([]entity.Organization), args.Error(1) +} + +func (m *MockShellStore) GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error) { + args := m.Called(organizationID, options) + return args.Get(0).([]entity.Workspace), args.Error(1) +} + +func (m *MockShellStore) StartWorkspace(workspaceID string) (*entity.Workspace, error) { + args := m.Called(workspaceID) + return args.Get(0).(*entity.Workspace), args.Error(1) +} + +func (m *MockShellStore) GetWorkspace(workspaceID string) (*entity.Workspace, error) { + args := m.Called(workspaceID) + return args.Get(0).(*entity.Workspace), args.Error(1) +} + +func (m *MockShellStore) GetCurrentUserKeys() (*entity.UserKeys, error) { + args := m.Called() + return args.Get(0).(*entity.UserKeys), args.Error(1) +} + +func (m *MockShellStore) GetCurrentUser() (*entity.User, error) { + args := m.Called() + return args.Get(0).(*entity.User), args.Error(1) +} + +func (m *MockShellStore) GetWorkspaceByNameOrIDErr(orgID, workspaceNameOrID string) (*entity.Workspace, error) { + args := m.Called(orgID, workspaceNameOrID) + return args.Get(0).(*entity.Workspace), args.Error(1) +} + +func (m *MockShellStore) CopyBin(targetBin string) error { + args := m.Called(targetBin) + return args.Error(0) +} + +func (m *MockShellStore) DoesJetbrainsFilePathExist() (bool, error) { + args := m.Called() + return args.Get(0).(bool), args.Error(1) +} + +func (m *MockShellStore) DownloadBinary(url string, target string) error { + args := m.Called(url, target) + return args.Error(0) +} + +func (m *MockShellStore) FileExists(filepath string) (bool, error) { + args := m.Called(filepath) + return args.Get(0).(bool), args.Error(1) +} + +func (m *MockShellStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + args := m.Called() + return args.Get(0).(*entity.Organization), args.Error(1) +} + +func (m *MockShellStore) GetBrevSSHConfigPath() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) GetContextWorkspaces() ([]entity.Workspace, error) { + args := m.Called() + return args.Get(0).([]entity.Workspace), args.Error(1) +} + +func (m *MockShellStore) WriteBrevSSHConfig(config string) error { + args := m.Called(config) + return args.Error(0) +} + +func (m *MockShellStore) GetUserSSHConfig() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) WriteUserSSHConfig(config string) error { + args := m.Called(config) + return args.Error(0) +} + +func (m *MockShellStore) GetPrivateKeyPath() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) GetUserSSHConfigPath() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) GetJetBrainsConfigPath() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) GetJetBrainsConfig() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) WriteJetBrainsConfig(config string) error { + args := m.Called(config) + return args.Error(0) +} + +func (m *MockShellStore) GetWSLHostUserSSHConfigPath() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) GetWindowsDir() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) WriteBrevSSHConfigWSL(config string) error { + args := m.Called(config) + return args.Error(0) +} + +func (m *MockShellStore) GetFileAsString(path string) (string, error) { + args := m.Called(path) + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) GetWSLHostBrevSSHConfigPath() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) GetWSLUserSSHConfig() (string, error) { + args := m.Called() + return args.Get(0).(string), args.Error(1) +} + +func (m *MockShellStore) WriteWSLUserSSHConfig(config string) error { + args := m.Called(config) + return args.Error(0) +} + +func (m *MockShellStore) GetWorkspaceByNameOrID(orgID string, nameOrID string) ([]entity.Workspace, error) { + args := m.Called(orgID, nameOrID) + return args.Get(0).([]entity.Workspace), args.Error(1) +} + +func TestNewCmdShell_SSHClientNotInstalled(t *testing.T) { + mockTerminal, outBuffer, _, errBuffer := terminal.NewTestTerminal() + mockStore := new(MockShellStore) + + cmd := newCmdShellWithDeps(mockTerminal, mockStore, mockStore, func(name string, arg ...string) *exec.Cmd { + return exec.Command("false") + }) + + cmd.SetOut(outBuffer) + cmd.SetErr(errBuffer) + cmd.SetArgs([]string{"workspace-1"}) + + // Run the command + err := cmd.Execute() + + // Assertions + assert.Error(t, err) + assert.Contains(t, err.Error(), "SSH client is not installed or not available in PATH") +} + +func TestNewCmdShell_SSHClientWorkspaceNotFound(t *testing.T) { + mockTerminal, outBuffer, _, errBuffer := terminal.NewTestTerminal() + mockStore := new(MockShellStore) + + testUser := &entity.User{ + ID: "test-user-id", + Name: "Test User", + } + testOrg := &entity.Organization{ + ID: "test-org-id", + Name: "test-org", + } + testWorkspaces := []entity.Workspace{ + { + ID: "workspace-1", + Name: "Workspace 1", + CreatedByUserID: "test-user-id", + Status: entity.Running, + }, + } + mockStore.On("GetCurrentUser").Return(testUser, nil) + mockStore.On("GetActiveOrganizationOrDefault").Return(testOrg, nil) + mockStore.On("GetWorkspaceByNameOrID", testOrg.ID, "workspace-1").Return(testWorkspaces, nil) + mockStore.On("GetWorkspace", "workspace-1").Return((*entity.Workspace)(nil), errors.New("workspace not found")) + + cmd := NewCmdShell(mockTerminal, mockStore, mockStore) + + cmd.SetOut(outBuffer) + cmd.SetErr(errBuffer) + cmd.SetArgs([]string{"workspace-1"}) + + err := cmd.Execute() + + assert.Error(t, err) + assert.Contains(t, errBuffer.String(), "waiting for instance to be ready") + assert.Contains(t, errBuffer.String(), "workspace not found") +} diff --git a/pkg/terminal/terminal.go b/pkg/terminal/terminal.go index a492fb56..7763af9c 100644 --- a/pkg/terminal/terminal.go +++ b/pkg/terminal/terminal.go @@ -2,6 +2,7 @@ package terminal import ( + "bytes" "fmt" "io" "os" @@ -50,12 +51,26 @@ func New() (t *Terminal) { } } -func (t *Terminal) SetVerbose(verbose bool) { - if verbose { - t.out = os.Stdout - } else { - t.out = silentWriter{} - } +// NewTestTerminal creates a new Terminal instance with buffers for testing +func NewTestTerminal() (*Terminal, *bytes.Buffer, *bytes.Buffer, *bytes.Buffer) { + outBuffer := &bytes.Buffer{} + verboseBuffer := &bytes.Buffer{} + errBuffer := &bytes.Buffer{} + + return &Terminal{ + out: outBuffer, + verbose: verboseBuffer, + err: errBuffer, + Green: color.New(color.FgGreen).SprintfFunc(), + Yellow: color.New(color.FgYellow).SprintfFunc(), + Red: color.New(color.FgRed).SprintfFunc(), + Blue: color.New(color.FgBlue).SprintfFunc(), + White: color.New(color.FgWhite, color.Bold).SprintfFunc(), + }, outBuffer, verboseBuffer, errBuffer +} + +func (t *Terminal) Out() io.Writer { + return t.out } func (t *Terminal) Print(a string) { @@ -109,7 +124,7 @@ func (w silentWriter) Write(_ []byte) (n int, err error) { } func (t *Terminal) NewSpinner() *spinner.Spinner { - spinner := spinner.New(spinner.CharSets[11], 100*time.Millisecond, spinner.WithWriter(os.Stderr)) + spinner := spinner.New(spinner.CharSets[11], 100*time.Millisecond, spinner.WithWriter(t.err)) err := spinner.Color("cyan", "bold") if err != nil { t.Errprint(err, "")