|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "fmt" |
| 7 | + "net" |
| 8 | + "os" |
| 9 | + "strings" |
| 10 | + "testing" |
| 11 | + |
| 12 | + "github.com/argoproj/argo-cd/v2/reposerver/askpass" |
| 13 | + "github.com/argoproj/argo-cd/v2/util/git" |
| 14 | + "github.com/spf13/cobra" |
| 15 | + "github.com/stretchr/testify/assert" |
| 16 | + "google.golang.org/grpc" |
| 17 | + "google.golang.org/grpc/credentials/insecure" |
| 18 | + "google.golang.org/grpc/test/bufconn" |
| 19 | +) |
| 20 | + |
| 21 | +const bufSize = 1024 * 1024 |
| 22 | + |
| 23 | +var lis *bufconn.Listener |
| 24 | + |
| 25 | +func init() { |
| 26 | + lis = bufconn.Listen(bufSize) |
| 27 | + s := grpc.NewServer() |
| 28 | + askpass.RegisterAskPassServiceServer(s, &mockAskPassServer{}) |
| 29 | + go func() { |
| 30 | + _ = s.Serve(lis) |
| 31 | + }() |
| 32 | +} |
| 33 | + |
| 34 | +type mockAskPassServer struct { |
| 35 | + askpass.UnimplementedAskPassServiceServer |
| 36 | +} |
| 37 | + |
| 38 | +func (m *mockAskPassServer) GetCredentials(ctx context.Context, req *askpass.CredentialsRequest) (*askpass.CredentialsResponse, error) { |
| 39 | + return &askpass.CredentialsResponse{Username: "testuser", Password: "testpassword"}, nil |
| 40 | +} |
| 41 | + |
| 42 | +func bufDialer(context.Context, string) (net.Conn, error) { |
| 43 | + return lis.Dial() |
| 44 | +} |
| 45 | + |
| 46 | +func NewTestCommand() *cobra.Command { |
| 47 | + cmd := NewAskPassCommand() |
| 48 | + cmd.Run = func(c *cobra.Command, args []string) { |
| 49 | + ctx := c.Context() |
| 50 | + if len(args) != 1 { |
| 51 | + fmt.Fprintf(c.ErrOrStderr(), "expected 1 argument, got %d\n", len(args)) |
| 52 | + return |
| 53 | + } |
| 54 | + nonce := os.Getenv(git.ASKPASS_NONCE_ENV) |
| 55 | + if nonce == "" { |
| 56 | + fmt.Fprintf(c.ErrOrStderr(), "%s is not set\n", git.ASKPASS_NONCE_ENV) |
| 57 | + return |
| 58 | + } |
| 59 | + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) |
| 60 | + if err != nil { |
| 61 | + fmt.Fprintf(c.ErrOrStderr(), "failed to connect: %v\n", err) |
| 62 | + return |
| 63 | + } |
| 64 | + defer conn.Close() |
| 65 | + client := askpass.NewAskPassServiceClient(conn) |
| 66 | + creds, err := client.GetCredentials(ctx, &askpass.CredentialsRequest{Nonce: nonce}) |
| 67 | + if err != nil { |
| 68 | + fmt.Fprintf(c.ErrOrStderr(), "failed to get credentials: %v\n", err) |
| 69 | + return |
| 70 | + } |
| 71 | + switch { |
| 72 | + case strings.HasPrefix(args[0], "Username"): |
| 73 | + fmt.Fprintln(c.OutOrStdout(), creds.Username) |
| 74 | + case strings.HasPrefix(args[0], "Password"): |
| 75 | + fmt.Fprintln(c.OutOrStdout(), creds.Password) |
| 76 | + default: |
| 77 | + fmt.Fprintf(c.ErrOrStderr(), "unknown credential type '%s'\n", args[0]) |
| 78 | + } |
| 79 | + } |
| 80 | + return cmd |
| 81 | +} |
| 82 | + |
| 83 | +func TestNewAskPassCommand(t *testing.T) { |
| 84 | + testCases := []struct { |
| 85 | + name string |
| 86 | + args []string |
| 87 | + envNonce string |
| 88 | + expectedOut string |
| 89 | + expectedErr string |
| 90 | + }{ |
| 91 | + {"no arguments", []string{}, "testnonce", "", "expected 1 argument, got 0"}, |
| 92 | + {"missing nonce", []string{"Username"}, "", "", fmt.Sprintf("%s is not set", git.ASKPASS_NONCE_ENV)}, |
| 93 | + {"valid username request", []string{"Username"}, "testnonce", "testuser", ""}, |
| 94 | + {"valid password request", []string{"Password"}, "testnonce", "testpassword", ""}, |
| 95 | + {"unknown credential type", []string{"Unknown"}, "testnonce", "", "unknown credential type 'Unknown'"}, |
| 96 | + } |
| 97 | + |
| 98 | + for _, tc := range testCases { |
| 99 | + t.Run(tc.name, func(t *testing.T) { |
| 100 | + os.Clearenv() |
| 101 | + if tc.envNonce != "" { |
| 102 | + os.Setenv(git.ASKPASS_NONCE_ENV, tc.envNonce) |
| 103 | + } |
| 104 | + |
| 105 | + var stdout, stderr bytes.Buffer |
| 106 | + command := NewTestCommand() |
| 107 | + command.SetArgs(tc.args) |
| 108 | + command.SetOut(&stdout) |
| 109 | + command.SetErr(&stderr) |
| 110 | + |
| 111 | + err := command.Execute() |
| 112 | + |
| 113 | + if tc.expectedOut != "" { |
| 114 | + assert.Equal(t, tc.expectedOut, strings.TrimSpace(stdout.String())) |
| 115 | + } |
| 116 | + |
| 117 | + if tc.expectedErr != "" { |
| 118 | + assert.Contains(t, stderr.String(), tc.expectedErr) |
| 119 | + } else { |
| 120 | + assert.NoError(t, err) |
| 121 | + } |
| 122 | + }) |
| 123 | + } |
| 124 | +} |
0 commit comments