diff --git a/util/api_key_helper.go b/util/api_key_helper.go index 465709f..3c8202c 100644 --- a/util/api_key_helper.go +++ b/util/api_key_helper.go @@ -1,17 +1,13 @@ package util import ( - "bytes" "context" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "os" - "os/exec" "path/filepath" - "strings" - "syscall" "time" ) @@ -127,91 +123,16 @@ func needsRefresh(cache *apiKeyCache, refreshInterval time.Duration) bool { } // GetAPIKeyFromHelper executes a shell command to dynamically generate an API key. -// The command is executed in /bin/sh with a timeout controlled by the provided context. +// Platform-specific implementations are in api_key_helper_unix.go and api_key_helper_windows.go. +// +// The command is executed with a timeout controlled by the provided context. // It returns the trimmed output from stdout, or an error if the command fails. // -// On timeout, it kills the entire process group (shell and all descendants) using -// a two-phase approach: SIGTERM for graceful termination, then SIGKILL if needed. +// On timeout: +// - Unix/Linux/macOS: kills the entire process group (shell and all descendants) +// - Windows: terminates the Job Object (cmd.exe and all descendants) // // Security note: The returned API key is sensitive and should not be logged. -func GetAPIKeyFromHelper(ctx context.Context, helperCmd string) (string, error) { - if helperCmd == "" { - return "", fmt.Errorf("api_key_helper command is empty") - } - - // Create context with timeout if not already set - if _, hasDeadline := ctx.Deadline(); !hasDeadline { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, HelperTimeout) - defer cancel() - } - - // Execute command in /bin/sh - cmd := exec.CommandContext(ctx, "/bin/sh", "-c", helperCmd) - - // Create a new process group so we can kill all descendants on timeout - cmd.SysProcAttr = &syscall.SysProcAttr{ - Setpgid: true, - } - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - // Start the command - if err := cmd.Start(); err != nil { - return "", fmt.Errorf("api_key_helper start failed: %w", err) - } - - // Wait for command completion in a goroutine - done := make(chan error, 1) - go func() { - // Always Wait to avoid zombie processes - done <- cmd.Wait() - }() - - select { - case err := <-done: - // Command completed normally - if err != nil { - // Don't include stderr in error message as it might contain sensitive info - return "", fmt.Errorf("api_key_helper command failed: %w", err) - } - apiKey := strings.TrimSpace(stdout.String()) - if apiKey == "" { - return "", fmt.Errorf("api_key_helper command returned empty output") - } - return apiKey, nil - - case <-ctx.Done(): - // Timeout or cancellation: terminate the process group gracefully, then forcefully - pgid := cmd.Process.Pid - - // First attempt: send SIGTERM to the entire process group for graceful shutdown - _ = syscall.Kill(-pgid, syscall.SIGTERM) - - // Wait for graceful termination with a grace period - select { - case err := <-done: - if err != nil { - return "", fmt.Errorf("api_key_helper terminated after timeout: %w", err) - } - apiKey := strings.TrimSpace(stdout.String()) - if apiKey == "" { - return "", fmt.Errorf( - "api_key_helper command returned empty output after timeout termination", - ) - } - return apiKey, nil - - case <-time.After(2 * time.Second): - // Grace period expired: send SIGKILL to force termination - _ = syscall.Kill(-pgid, syscall.SIGKILL) - <-done // Wait for cleanup - return "", fmt.Errorf("api_key_helper command timed out after %v", HelperTimeout) - } - } -} // GetAPIKeyFromHelperWithCache executes a shell command to dynamically generate an API key, // with file-based caching support. The API key is cached for the specified refresh interval. diff --git a/util/api_key_helper_unix.go b/util/api_key_helper_unix.go new file mode 100644 index 0000000..3cec66f --- /dev/null +++ b/util/api_key_helper_unix.go @@ -0,0 +1,98 @@ +//go:build !windows + +package util + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + "syscall" + "time" +) + +// GetAPIKeyFromHelper executes a shell command to dynamically generate an API key. +// The command is executed in /bin/sh with a timeout controlled by the provided context. +// It returns the trimmed output from stdout, or an error if the command fails. +// +// On timeout, it kills the entire process group (shell and all descendants) using +// a two-phase approach: SIGTERM for graceful termination, then SIGKILL if needed. +// +// Security note: The returned API key is sensitive and should not be logged. +func GetAPIKeyFromHelper(ctx context.Context, helperCmd string) (string, error) { + if helperCmd == "" { + return "", fmt.Errorf("api_key_helper command is empty") + } + + // Create context with timeout if not already set + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, HelperTimeout) + defer cancel() + } + + // Execute command in /bin/sh + cmd := exec.CommandContext(ctx, "/bin/sh", "-c", helperCmd) + + // Create a new process group so we can kill all descendants on timeout + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + // Start the command + if err := cmd.Start(); err != nil { + return "", fmt.Errorf("api_key_helper start failed: %w", err) + } + + // Wait for command completion in a goroutine + done := make(chan error, 1) + go func() { + // Always Wait to avoid zombie processes + done <- cmd.Wait() + }() + + select { + case err := <-done: + // Command completed normally + if err != nil { + // Don't include stderr in error message as it might contain sensitive info + return "", fmt.Errorf("api_key_helper command failed: %w", err) + } + apiKey := strings.TrimSpace(stdout.String()) + if apiKey == "" { + return "", fmt.Errorf("api_key_helper command returned empty output") + } + return apiKey, nil + + case <-ctx.Done(): + // Timeout or cancellation: terminate the process group gracefully, then forcefully + if cmd.Process == nil { + // Process handle not initialized; wait for cleanup and report timeout + <-done + return "", fmt.Errorf("api_key_helper command timeout after %v", HelperTimeout) + } + pgid := cmd.Process.Pid + + // First attempt: send SIGTERM to the entire process group for graceful shutdown + _ = syscall.Kill(-pgid, syscall.SIGTERM) + + // Wait for graceful termination with a grace period + select { + case <-done: + // Process exited after timeout was reached; treat as timeout regardless of exit status. + // We intentionally ignore stdout/stderr here to avoid returning a key after a timeout. + return "", fmt.Errorf("api_key_helper command timeout after %v", HelperTimeout) + + case <-time.After(2 * time.Second): + // Grace period expired: send SIGKILL to force termination + _ = syscall.Kill(-pgid, syscall.SIGKILL) + <-done // Wait for cleanup + return "", fmt.Errorf("api_key_helper command timeout after %v", HelperTimeout) + } + } +} diff --git a/util/api_key_helper_windows.go b/util/api_key_helper_windows.go new file mode 100644 index 0000000..e22cc93 --- /dev/null +++ b/util/api_key_helper_windows.go @@ -0,0 +1,148 @@ +//go:build windows + +package util + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// createKillOnCloseJob creates a Windows Job Object with KILL_ON_JOB_CLOSE flag. +// When the job handle is closed, all processes in the job will be terminated. +func createKillOnCloseJob() (windows.Handle, error) { + job, err := windows.CreateJobObject(nil, nil) + if err != nil { + return 0, err + } + + var info windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION + // Enable KILL_ON_JOB_CLOSE flag + info.BasicLimitInformation.LimitFlags = windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE + + _, err = windows.SetInformationJobObject( + job, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + ) + if err != nil { + _ = windows.CloseHandle(job) + return 0, err + } + return job, nil +} + +// assignProcessToJob assigns a process to a Job Object. +// Returns the process handle which should be closed by the caller. +func assignProcessToJob(job windows.Handle, pid int) (windows.Handle, error) { + // Validate PID range to prevent overflow + if pid < 0 || pid > 0x7FFFFFFF { + return 0, fmt.Errorf("invalid process ID: %d", pid) + } + + // Get child process handle (requires PROCESS_ALL_ACCESS) + // #nosec G115 -- PID validated above to prevent overflow + hProc, err := windows.OpenProcess(windows.PROCESS_ALL_ACCESS, false, uint32(pid)) + if err != nil { + return 0, err + } + // Assign to Job + if err = windows.AssignProcessToJobObject(job, hProc); err != nil { + _ = windows.CloseHandle(hProc) + return 0, err + } + return hProc, nil +} + +// GetAPIKeyFromHelper executes a shell command to dynamically generate an API key. +// The command is executed in cmd.exe with a timeout controlled by the provided context. +// It returns the trimmed output from stdout, or an error if the command fails. +// +// On timeout, it terminates the entire Job Object (cmd.exe and all descendants). +// +// Security note: The returned API key is sensitive and should not be logged. +func GetAPIKeyFromHelper(ctx context.Context, helperCmd string) (string, error) { + if helperCmd == "" { + return "", fmt.Errorf("api_key_helper command is empty") + } + + // Create context with timeout if not already set + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, HelperTimeout) + defer cancel() + } + + // Execute command in cmd.exe + cmd := exec.CommandContext(ctx, "cmd.exe", "/c", helperCmd) + + // Use CREATE_NEW_PROCESS_GROUP and CREATE_BREAKAWAY_FROM_JOB flags + // This allows the child process to be assigned to a new Job, + // even if the parent process is already in a Job + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: windows.CREATE_NEW_PROCESS_GROUP | windows.CREATE_BREAKAWAY_FROM_JOB, + } + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + // Create Job Object first + job, err := createKillOnCloseJob() + if err != nil { + return "", fmt.Errorf("create job failed: %w", err) + } + // With KILL_ON_JOB_CLOSE, closing the job will kill all processes + defer func() { + _ = windows.CloseHandle(job) + }() + + // Start the child process + if err = cmd.Start(); err != nil { + return "", fmt.Errorf("api_key_helper start failed: %w", err) + } + + // Assign child process to Job + hProc, err := assignProcessToJob(job, cmd.Process.Pid) + if err != nil { + // If unable to breakaway due to policy, fall back to just killing the process + // (but this won't guarantee killing grandchild processes) + _ = cmd.Process.Kill() + _ = cmd.Wait() + return "", fmt.Errorf("assign process to job failed: %w", err) + } + defer func() { + _ = windows.CloseHandle(hProc) + }() + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case err := <-done: + if err != nil { + // Don't include stderr in error message as it might contain sensitive info + return "", fmt.Errorf("api_key_helper command failed: %w", err) + } + apiKey := strings.TrimSpace(stdout.String()) + if apiKey == "" { + return "", fmt.Errorf("api_key_helper command returned empty output") + } + return apiKey, nil + + case <-ctx.Done(): + // Timeout: terminate the entire Job (all descendants) + _ = windows.TerminateJobObject(job, 1) + <-done // Wait for cleanup + return "", fmt.Errorf("api_key_helper command timeout after %v", HelperTimeout) + } +} diff --git a/util/api_key_helper_windows_test.go b/util/api_key_helper_windows_test.go new file mode 100644 index 0000000..584dfbd --- /dev/null +++ b/util/api_key_helper_windows_test.go @@ -0,0 +1,370 @@ +//go:build windows + +package util + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +func TestCreateKillOnCloseJob(t *testing.T) { + job, err := createKillOnCloseJob() + if err != nil { + t.Fatalf("createKillOnCloseJob() error = %v, want nil", err) + } + defer func() { + _ = windows.CloseHandle(job) + }() + + // Verify that the job handle is valid + if job == 0 { + t.Error("createKillOnCloseJob() returned invalid handle") + } + + // Verify that KILL_ON_JOB_CLOSE flag is set + var info windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION + var returnLength uint32 + err = windows.QueryInformationJobObject( + job, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + &returnLength, + ) + if err != nil { + t.Fatalf("QueryInformationJobObject() error = %v", err) + } + + if info.BasicLimitInformation.LimitFlags&windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE == 0 { + t.Error("KILL_ON_JOB_CLOSE flag is not set") + } +} + +func TestAssignProcessToJob_InvalidPID(t *testing.T) { + job, err := createKillOnCloseJob() + if err != nil { + t.Fatalf("createKillOnCloseJob() error = %v", err) + } + defer func() { + _ = windows.CloseHandle(job) + }() + + tests := []struct { + name string + pid int + }{ + { + name: "negative PID", + pid: -1, + }, + { + name: "PID exceeds max", + pid: 0x80000000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := assignProcessToJob(job, tt.pid) + if err == nil { + t.Error("assignProcessToJob() should return error for invalid PID") + } + if !strings.Contains(err.Error(), "invalid process ID") { + t.Errorf("error should mention invalid PID, got: %v", err) + } + }) + } +} + +func TestAssignProcessToJob_NonExistentPID(t *testing.T) { + job, err := createKillOnCloseJob() + if err != nil { + t.Fatalf("createKillOnCloseJob() error = %v", err) + } + defer func() { + _ = windows.CloseHandle(job) + }() + + // Use a PID that likely doesn't exist (but is valid range) + nonExistentPID := 99999 + + _, err = assignProcessToJob(job, nonExistentPID) + if err == nil { + t.Error("assignProcessToJob() should return error for non-existent PID") + } +} + +func TestGetAPIKeyFromHelper_Windows_Success(t *testing.T) { + tests := []struct { + name string + command string + expected string + }{ + { + name: "simple echo command", + command: "echo test-api-key", + expected: "test-api-key", + }, + { + name: "command with whitespace", + command: "echo test-key-with-spaces ", + expected: "test-key-with-spaces", + }, + { + name: "powershell command", + command: `powershell -Command "Write-Output 'ps-key'"`, + expected: "ps-key", + }, + { + name: "set and echo variable", + command: "set KEY=win-key && echo %KEY%", + expected: "win-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := GetAPIKeyFromHelper(context.Background(), tt.command) + if err != nil { + t.Fatalf("GetAPIKeyFromHelper() error = %v, want nil", err) + } + if result != tt.expected { + t.Errorf("GetAPIKeyFromHelper() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetAPIKeyFromHelper_Windows_Timeout(t *testing.T) { + // Use timeout command (Windows specific) + // This will sleep for 15 seconds, which is longer than HelperTimeout (10s) + command := "timeout /t 15 /nobreak >nul" + + start := time.Now() + _, err := GetAPIKeyFromHelper(context.Background(), command) + duration := time.Since(start) + + if err == nil { + t.Fatal("GetAPIKeyFromHelper() should return timeout error") + } + + if !strings.Contains(err.Error(), "timeout") { + t.Errorf("error message should mention timeout, got: %v", err) + } + + // Verify it actually timed out around the expected timeout duration + // Allow up to 2 seconds margin + if duration < HelperTimeout || duration > HelperTimeout+2*time.Second { + t.Errorf("timeout duration = %v, want around %v", duration, HelperTimeout) + } +} + +func TestGetAPIKeyFromHelper_Windows_KillProcessTree(t *testing.T) { + // Test that the Job Object kills the entire process tree + // Create a command that spawns child processes + command := `cmd /c "timeout /t 15 /nobreak >nul & timeout /t 15 /nobreak >nul"` + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + start := time.Now() + _, err := GetAPIKeyFromHelper(ctx, command) + duration := time.Since(start) + + if err == nil { + t.Fatal("GetAPIKeyFromHelper() should return timeout error") + } + + // Should timeout quickly (around 2 seconds, not 15) + if duration > 3*time.Second { + t.Errorf("timeout took too long: %v, expected around 2s", duration) + } +} + +func TestGetAPIKeyFromHelper_Windows_CommandFailure(t *testing.T) { + tests := []struct { + name string + command string + wantErr string + }{ + { + name: "non-existent command", + command: "nonexistentcommand12345", + wantErr: "failed", + }, + { + name: "command with exit code 1", + command: "exit 1", + wantErr: "failed", + }, + { + name: "invalid syntax", + command: "echo %UNDEFINED_VAR && exit 1", + wantErr: "failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := GetAPIKeyFromHelper(context.Background(), tt.command) + if err == nil { + t.Fatal("GetAPIKeyFromHelper() should return error for failed command") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error should contain %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestGetAPIKeyFromHelper_Windows_EmptyOutput(t *testing.T) { + tests := []struct { + name string + command string + }{ + { + name: "command with no output", + command: "rem no output", + }, + { + name: "command outputting only whitespace", + command: "echo ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := GetAPIKeyFromHelper(context.Background(), tt.command) + if err == nil { + t.Fatal("GetAPIKeyFromHelper() with empty output should return error") + } + if !strings.Contains(err.Error(), "empty output") { + t.Errorf("error message should mention empty output, got: %v", err) + } + }) + } +} + +func TestGetAPIKeyFromHelper_Windows_SecurityStderr(t *testing.T) { + // Command that outputs to stderr (sensitive info should not be leaked in error) + command := "echo secret-data 1>&2 && exit 1" + + _, err := GetAPIKeyFromHelper(context.Background(), command) + if err == nil { + t.Fatal("GetAPIKeyFromHelper() should return error when command fails") + } + + // The error message should NOT contain the stderr output (security consideration) + if strings.Contains(err.Error(), "secret-data") { + t.Error("error message should not leak stderr content (security issue)") + } +} + +func TestGetAPIKeyFromHelper_Windows_ComplexCommands(t *testing.T) { + tests := []struct { + name string + command string + expected string + }{ + { + name: "piped commands", + command: "echo my-api-key | findstr api", + expected: "my-api-key", + }, + { + name: "command with variable substitution", + command: "set KEY=test-123 && echo %KEY%", + expected: "test-123", + }, + { + name: "for loop", + command: `for /F %i in ('echo nested-key') do @echo %i`, + expected: "nested-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := GetAPIKeyFromHelper(context.Background(), tt.command) + if err != nil { + t.Fatalf("GetAPIKeyFromHelper() error = %v, want nil", err) + } + if result != tt.expected { + t.Errorf("GetAPIKeyFromHelper() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetAPIKeyFromHelper_Windows_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + // Start a long-running command + done := make(chan error, 1) + go func() { + _, err := GetAPIKeyFromHelper(ctx, "timeout /t 30 /nobreak >nul") + done <- err + }() + + // Cancel after a short delay + time.Sleep(500 * time.Millisecond) + cancel() + + // Wait for the command to be cancelled + select { + case err := <-done: + if err == nil { + t.Error("GetAPIKeyFromHelper() should return error on context cancellation") + } + if !strings.Contains(err.Error(), "timeout") { + t.Errorf("error should mention timeout, got: %v", err) + } + case <-time.After(5 * time.Second): + t.Error("GetAPIKeyFromHelper() took too long to respond to cancellation") + } +} + +func TestGetAPIKeyFromHelper_Windows_MultipleInvocations(t *testing.T) { + results := make(chan string, 3) + errors := make(chan error, 3) + + for i := 0; i < 3; i++ { + go func(n int) { + result, err := GetAPIKeyFromHelper( + context.Background(), + fmt.Sprintf("echo test-key-%d", n), + ) + if err != nil { + errors <- err + } else { + results <- result + } + }(i) + } + + // Collect results + successCount := 0 + for i := 0; i < 3; i++ { + select { + case result := <-results: + if !strings.HasPrefix(result, "test-key-") { + t.Errorf("unexpected result: %s", result) + } + successCount++ + case err := <-errors: + t.Errorf("unexpected error: %v", err) + case <-time.After(5 * time.Second): + t.Error("timeout waiting for results") + } + } + + if successCount != 3 { + t.Errorf("expected 3 successful invocations, got %d", successCount) + } +}