diff --git a/libmachine/shell/shell_windows.go b/libmachine/shell/shell_windows.go index 89cd2c8b0a..3cb02b0206 100644 --- a/libmachine/shell/shell_windows.go +++ b/libmachine/shell/shell_windows.go @@ -80,5 +80,7 @@ func Detect() (string, error) { return "fish", nil } + shell = strings.TrimSuffix(shell, filepath.Ext(shell)) + return filepath.Base(shell), nil } diff --git a/libmachine/shell/shell_windows_test.go b/libmachine/shell/shell_windows_test.go index 81c0c50705..9fe0601c5b 100644 --- a/libmachine/shell/shell_windows_test.go +++ b/libmachine/shell/shell_windows_test.go @@ -17,6 +17,16 @@ func TestDetect(t *testing.T) { assert.NoError(t, err) } +func TestDetectOnSSH(t *testing.T) { + defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) + os.Setenv("SHELL", "c:\\windows\\system32\\windowspowershell\\v1.0\\powershell.exe") + + shell, err := Detect() + + assert.Equal(t, "powershell", shell) + assert.NoError(t, err) +} + func TestGetNameAndItsPpidOfCurrent(t *testing.T) { shell, shellppid, err := getNameAndItsPpid(os.Getpid())