diff --git a/packages/api/api.go b/packages/api/api.go index c4acd51..d2b454e 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -53,6 +53,7 @@ const ( operationCallGetPamSessionKey = "CallGetPamSessionKey" operationCallUploadPamSessionLog = "CallUploadPamSessionLog" operationCallPAMSessionTermination = "CallPAMSessionTermination" + operationCallGetMFASessionStatus = "CallGetMFASessionStatus" operationCallOrgRelayHeartBeat = "CallOrgRelayHeartBeat" operationCallInstanceRelayHeartBeat = "CallInstanceRelayHeartBeat" operationCallIssueCertificate = "CallIssueCertificate" @@ -1000,6 +1001,25 @@ func CallPAMSessionTermination(httpClient *resty.Client, sessionId string) error return nil } +func CallGetMFASessionStatus(httpClient *resty.Client, mfaSessionId string) (MFASessionStatusResponse, error) { + var mfaSessionStatusResponse MFASessionStatusResponse + response, err := httpClient. + R(). + SetResult(&mfaSessionStatusResponse). + SetHeader("User-Agent", USER_AGENT). + Get(fmt.Sprintf("%v/v2/mfa-sessions/%s/status", config.INFISICAL_URL, mfaSessionId)) + + if err != nil { + return MFASessionStatusResponse{}, NewGenericRequestError(operationCallGetMFASessionStatus, err) + } + + if response.IsError() { + return MFASessionStatusResponse{}, NewAPIErrorWithResponse(operationCallGetMFASessionStatus, response, nil) + } + + return mfaSessionStatusResponse, nil +} + func CallIssueCertificate(httpClient *resty.Client, request IssueCertificateRequest) (*CertificateResponse, error) { var resBody CertificateResponse response, err := httpClient. diff --git a/packages/api/errors.go b/packages/api/errors.go index 96d9ca1..28a369d 100644 --- a/packages/api/errors.go +++ b/packages/api/errors.go @@ -24,6 +24,7 @@ func NewGenericRequestError(operation string, err error) *GenericRequestError { // APIError represents an error response from the API type APIError struct { + Name string `json:"name"` AdditionalContext string `json:"additionalContext,omitempty"` ExtraMessages []string `json:"-"` Details any `json:"details,omitempty"` @@ -79,7 +80,7 @@ func (e APIError) Error() string { } func NewAPIErrorWithResponse(operation string, res *resty.Response, additionalContext *string) error { - errorMessage, details := TryParseErrorBody(res) + errorMessage, details, errorName := TryParseErrorBody(res) reqId := util.TryExtractReqId(res) if res == nil { @@ -88,6 +89,7 @@ func NewAPIErrorWithResponse(operation string, res *resty.Response, additionalCo apiError := &APIError{ Operation: operation, + Name: errorName, Method: res.Request.Method, URL: res.Request.URL, StatusCode: res.StatusCode(), @@ -134,26 +136,27 @@ type errorResponse struct { Message string `json:"message"` Details any `json:"details"` ReqId string `json:"reqId"` + Name string `json:"error"` } /* Instead of changing the signature of the sdk function - let's just keep a one local to this codebase */ -func TryParseErrorBody(res *resty.Response) (string, any) { +func TryParseErrorBody(res *resty.Response) (string, any, string) { var details any if res == nil || !res.IsError() { - return "", details + return "", details, "" } body := res.String() if body == "" { - return "", details + return "", details, "" } // stringify zod body entirely if res.StatusCode() == 422 { - return body, details + return body, details, "" } // now we have a string, we need to try to parse it as json @@ -162,7 +165,7 @@ func TryParseErrorBody(res *resty.Response) (string, any) { err := json.Unmarshal([]byte(body), &errorResponse) if err != nil { - return "", details + return "", details, "" } // Check if details is empty and return nil if so @@ -170,22 +173,22 @@ func TryParseErrorBody(res *resty.Response) (string, any) { switch v := errorResponse.Details.(type) { case []any: if len(v) == 0 { - return errorResponse.Message, nil + return errorResponse.Message, nil, errorResponse.Name } case []string: if len(v) == 0 { - return errorResponse.Message, nil + return errorResponse.Message, nil, errorResponse.Name } case map[string]any: if len(v) == 0 { - return errorResponse.Message, nil + return errorResponse.Message, nil, errorResponse.Name } case string: if v == "" { - return errorResponse.Message, nil + return errorResponse.Message, nil, errorResponse.Name } } } - return errorResponse.Message, errorResponse.Details + return errorResponse.Message, errorResponse.Details, errorResponse.Name } diff --git a/packages/api/model.go b/packages/api/model.go index fafb19a..cdd063e 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -787,9 +787,10 @@ type RegisterGatewayResponse struct { } type PAMAccessRequest struct { - Duration string `json:"duration,omitempty"` - AccountPath string `json:"accountPath,omitempty"` - ProjectId string `json:"projectId,omitempty"` + Duration string `json:"duration,omitempty"` + AccountPath string `json:"accountPath,omitempty"` + ProjectId string `json:"projectId,omitempty"` + MfaSessionId string `json:"mfaSessionId,omitempty"` } type PAMAccessResponse struct { @@ -843,6 +844,18 @@ type PAMSessionCredentials struct { ServiceAccountToken string `json:"serviceAccountToken,omitempty"` } +type MFASessionStatus string + +const ( + MFASessionStatusPending MFASessionStatus = "PENDING" + MFASessionStatusActive MFASessionStatus = "ACTIVE" +) + +type MFASessionStatusResponse struct { + Status MFASessionStatus `json:"status"` + MfaMethod string `json:"mfaMethod"` +} + type UploadSessionLogEntry struct { Timestamp time.Time `json:"timestamp"` Input string `json:"input"` diff --git a/packages/pam/local/base-proxy.go b/packages/pam/local/base-proxy.go index 3250683..06e705b 100644 --- a/packages/pam/local/base-proxy.go +++ b/packages/pam/local/base-proxy.go @@ -15,6 +15,7 @@ import ( "time" "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/config" "github.com/Infisical/infisical-merge/packages/pam" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" @@ -245,3 +246,44 @@ func (b *BaseProxyServer) WaitForConnectionsWithTimeout(timeout time.Duration) { log.Warn().Msg("Timeout waiting for connections to close, forcing shutdown") } } + +// CallPAMAccessWithMFA attempts to access a PAM account and handles MFA if required +// This is a shared function used by both database and SSH proxies +func CallPAMAccessWithMFA(httpClient *resty.Client, pamRequest api.PAMAccessRequest) (api.PAMAccessResponse, error) { + // Initial request + pamResponse, err := api.CallPAMAccess(httpClient, pamRequest) + if err != nil { + // Check if MFA is required + if apiErr, ok := err.(*api.APIError); ok { + if apiErr.Name == "SESSION_MFA_REQUIRED" { + // Extract MFA details from error + if details, ok := apiErr.Details.(map[string]interface{}); ok { + mfaSessionId, _ := details["mfaSessionId"].(string) + mfaMethod, _ := details["mfaMethod"].(string) + + if mfaSessionId != "" { + // Handle MFA flow + err := util.HandleMFASession(httpClient, mfaSessionId, mfaMethod, config.INFISICAL_URL) + if err != nil { + return api.PAMAccessResponse{}, fmt.Errorf("MFA verification failed: %w", err) + } + + // Retry request with MFA session ID + log.Debug().Msg("Retrying PAM access with MFA session...") + pamRequest.MfaSessionId = mfaSessionId + pamResponse, err = api.CallPAMAccess(httpClient, pamRequest) + if err != nil { + return api.PAMAccessResponse{}, fmt.Errorf("failed to access PAM account after MFA: %w", err) + } + + return pamResponse, nil + } + } + } + } + // Return original error if not MFA-related + return api.PAMAccessResponse{}, err + } + + return pamResponse, nil +} diff --git a/packages/pam/local/database-proxy.go b/packages/pam/local/database-proxy.go index e7c49f9..64525b8 100644 --- a/packages/pam/local/database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -61,7 +61,7 @@ func StartDatabaseLocalProxy(accessToken string, accountPath string, projectID s ProjectId: projectID, } - pamResponse, err := api.CallPAMAccess(httpClient, pamRequest) + pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest) if err != nil { var apiErr *api.APIError if errors.As(err, &apiErr) && apiErr.ErrorMessage == "A policy is in place for this resource" { diff --git a/packages/pam/local/ssh-proxy.go b/packages/pam/local/ssh-proxy.go index 086243e..2092c59 100644 --- a/packages/pam/local/ssh-proxy.go +++ b/packages/pam/local/ssh-proxy.go @@ -38,7 +38,7 @@ func StartSSHLocalProxy(accessToken string, accountPath string, projectID string ProjectId: projectID, } - pamResponse, err := api.CallPAMAccess(httpClient, pamRequest) + pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest) if err != nil { util.HandleError(err, "Failed to access PAM account") return diff --git a/packages/util/helper.go b/packages/util/helper.go index 4b1c730..30eebaa 100644 --- a/packages/util/helper.go +++ b/packages/util/helper.go @@ -21,6 +21,7 @@ import ( "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/models" + "github.com/go-resty/resty/v2" "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) @@ -558,11 +559,59 @@ func GenerateETagFromSecrets(secrets []models.SingleEnvironmentVariable) string func IsDevelopmentMode() bool { return CLI_VERSION == "devel" + +} + +// HandleMFASession opens a browser for MFA verification and polls until completion +func HandleMFASession(httpClient *resty.Client, mfaSessionId string, mfaMethod string, infisicalURL string) error { + // Construct MFA URL + mfaURL := fmt.Sprintf("%s/mfa-session/%s", strings.TrimSuffix(infisicalURL, "/api"), mfaSessionId) + + // Display MFA message + fmt.Printf("\nšŸ” MFA Verification Required (%s)\n", mfaMethod) + fmt.Printf("→ %s\n", mfaURL) + + // Try to open browser + if err := OpenBrowser(mfaURL); err != nil { + log.Debug().Err(err).Msg("Failed to open browser automatically") + } else { + fmt.Println("āœ“ Browser opened automatically") + } + + fmt.Println("ā³ Waiting for MFA verification...\n") + + // Poll for MFA completion + maxAttempts := 150 // 5 minutes at 2s intervals + pollInterval := 2 * time.Second + + for i := 0; i < maxAttempts; i++ { + time.Sleep(pollInterval) + + status, err := api.CallGetMFASessionStatus(httpClient, mfaSessionId) + if err != nil { + // Check if it's a 404 (session expired) + if apiErr, ok := err.(*api.APIError); ok { + if apiErr.StatusCode == 404 { + return fmt.Errorf("MFA session expired. Please try again") + } + } + // Continue polling on other errors + log.Debug().Err(err).Msg("Error polling MFA status, will retry") + continue + } + + if status.Status == api.MFASessionStatusActive { + return nil + } + } + + return fmt.Errorf("MFA verification timeout. Please try again") } // OpenBrowser attempts to open a URL in the user's default browser func OpenBrowser(url string) error { var cmd *exec.Cmd + switch runtime.GOOS { case "darwin": cmd = exec.Command("open", url) @@ -571,5 +620,6 @@ func OpenBrowser(url string) error { default: // linux and others cmd = exec.Command("xdg-open", url) } + return cmd.Start() }