Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions packages/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ const (
operationCallGetPamSessionKey = "CallGetPamSessionKey"
operationCallUploadPamSessionLog = "CallUploadPamSessionLog"
operationCallPAMSessionTermination = "CallPAMSessionTermination"
operationCallGetMFASessionStatus = "CallGetMFASessionStatus"
operationCallOrgRelayHeartBeat = "CallOrgRelayHeartBeat"
operationCallInstanceRelayHeartBeat = "CallInstanceRelayHeartBeat"
operationCallIssueCertificate = "CallIssueCertificate"
Expand Down Expand Up @@ -1000,6 +1001,25 @@ func CallPAMSessionTermination(httpClient *resty.Client, sessionId string) error
return nil
}

func CallGetMFASessionStatus(httpClient *resty.Client, mfaSessionId string) (MFASessionStatusResponse, error) {
var mfaSessionStatusResponse MFASessionStatusResponse
response, err := httpClient.
R().
SetResult(&mfaSessionStatusResponse).
SetHeader("User-Agent", USER_AGENT).
Get(fmt.Sprintf("%v/v2/mfa-sessions/%s/status", config.INFISICAL_URL, mfaSessionId))

if err != nil {
return MFASessionStatusResponse{}, NewGenericRequestError(operationCallGetMFASessionStatus, err)
}

if response.IsError() {
return MFASessionStatusResponse{}, NewAPIErrorWithResponse(operationCallGetMFASessionStatus, response, nil)
}

return mfaSessionStatusResponse, nil
}

func CallIssueCertificate(httpClient *resty.Client, request IssueCertificateRequest) (*CertificateResponse, error) {
var resBody CertificateResponse
response, err := httpClient.
Expand Down
25 changes: 14 additions & 11 deletions packages/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func NewGenericRequestError(operation string, err error) *GenericRequestError {

// APIError represents an error response from the API
type APIError struct {
Name string `json:"name"`
AdditionalContext string `json:"additionalContext,omitempty"`
ExtraMessages []string `json:"-"`
Details any `json:"details,omitempty"`
Expand Down Expand Up @@ -79,7 +80,7 @@ func (e APIError) Error() string {
}

func NewAPIErrorWithResponse(operation string, res *resty.Response, additionalContext *string) error {
errorMessage, details := TryParseErrorBody(res)
errorMessage, details, errorName := TryParseErrorBody(res)
reqId := util.TryExtractReqId(res)

if res == nil {
Expand All @@ -88,6 +89,7 @@ func NewAPIErrorWithResponse(operation string, res *resty.Response, additionalCo

apiError := &APIError{
Operation: operation,
Name: errorName,
Method: res.Request.Method,
URL: res.Request.URL,
StatusCode: res.StatusCode(),
Expand Down Expand Up @@ -134,26 +136,27 @@ type errorResponse struct {
Message string `json:"message"`
Details any `json:"details"`
ReqId string `json:"reqId"`
Name string `json:"error"`
}

/*
Instead of changing the signature of the sdk function - let's just keep a one local to this codebase
*/
func TryParseErrorBody(res *resty.Response) (string, any) {
func TryParseErrorBody(res *resty.Response) (string, any, string) {
var details any

if res == nil || !res.IsError() {
return "", details
return "", details, ""
}

body := res.String()
if body == "" {
return "", details
return "", details, ""
}

// stringify zod body entirely
if res.StatusCode() == 422 {
return body, details
return body, details, ""
}

// now we have a string, we need to try to parse it as json
Expand All @@ -162,30 +165,30 @@ func TryParseErrorBody(res *resty.Response) (string, any) {
err := json.Unmarshal([]byte(body), &errorResponse)

if err != nil {
return "", details
return "", details, ""
}

// Check if details is empty and return nil if so
if errorResponse.Details != nil {
switch v := errorResponse.Details.(type) {
case []any:
if len(v) == 0 {
return errorResponse.Message, nil
return errorResponse.Message, nil, errorResponse.Name
}
case []string:
if len(v) == 0 {
return errorResponse.Message, nil
return errorResponse.Message, nil, errorResponse.Name
}
case map[string]any:
if len(v) == 0 {
return errorResponse.Message, nil
return errorResponse.Message, nil, errorResponse.Name
}
case string:
if v == "" {
return errorResponse.Message, nil
return errorResponse.Message, nil, errorResponse.Name
}
}
}

return errorResponse.Message, errorResponse.Details
return errorResponse.Message, errorResponse.Details, errorResponse.Name
}
19 changes: 16 additions & 3 deletions packages/api/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,9 +787,10 @@ type RegisterGatewayResponse struct {
}

type PAMAccessRequest struct {
Duration string `json:"duration,omitempty"`
AccountPath string `json:"accountPath,omitempty"`
ProjectId string `json:"projectId,omitempty"`
Duration string `json:"duration,omitempty"`
AccountPath string `json:"accountPath,omitempty"`
ProjectId string `json:"projectId,omitempty"`
MfaSessionId string `json:"mfaSessionId,omitempty"`
}

type PAMAccessResponse struct {
Expand Down Expand Up @@ -843,6 +844,18 @@ type PAMSessionCredentials struct {
ServiceAccountToken string `json:"serviceAccountToken,omitempty"`
}

type MFASessionStatus string

const (
MFASessionStatusPending MFASessionStatus = "PENDING"
MFASessionStatusActive MFASessionStatus = "ACTIVE"
)

type MFASessionStatusResponse struct {
Status MFASessionStatus `json:"status"`
MfaMethod string `json:"mfaMethod"`
}

type UploadSessionLogEntry struct {
Timestamp time.Time `json:"timestamp"`
Input string `json:"input"`
Expand Down
42 changes: 42 additions & 0 deletions packages/pam/local/base-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/Infisical/infisical-merge/packages/api"
"github.com/Infisical/infisical-merge/packages/config"
"github.com/Infisical/infisical-merge/packages/pam"
"github.com/Infisical/infisical-merge/packages/util"
"github.com/go-resty/resty/v2"
Expand Down Expand Up @@ -245,3 +246,44 @@ func (b *BaseProxyServer) WaitForConnectionsWithTimeout(timeout time.Duration) {
log.Warn().Msg("Timeout waiting for connections to close, forcing shutdown")
}
}

// CallPAMAccessWithMFA attempts to access a PAM account and handles MFA if required
// This is a shared function used by both database and SSH proxies
func CallPAMAccessWithMFA(httpClient *resty.Client, pamRequest api.PAMAccessRequest) (api.PAMAccessResponse, error) {
// Initial request
pamResponse, err := api.CallPAMAccess(httpClient, pamRequest)
if err != nil {
// Check if MFA is required
if apiErr, ok := err.(*api.APIError); ok {
if apiErr.Name == "SESSION_MFA_REQUIRED" {
// Extract MFA details from error
if details, ok := apiErr.Details.(map[string]interface{}); ok {
mfaSessionId, _ := details["mfaSessionId"].(string)
mfaMethod, _ := details["mfaMethod"].(string)

if mfaSessionId != "" {
// Handle MFA flow
err := util.HandleMFASession(httpClient, mfaSessionId, mfaMethod, config.INFISICAL_URL)
if err != nil {
return api.PAMAccessResponse{}, fmt.Errorf("MFA verification failed: %w", err)
}

// Retry request with MFA session ID
log.Debug().Msg("Retrying PAM access with MFA session...")
pamRequest.MfaSessionId = mfaSessionId
pamResponse, err = api.CallPAMAccess(httpClient, pamRequest)
if err != nil {
return api.PAMAccessResponse{}, fmt.Errorf("failed to access PAM account after MFA: %w", err)
}

return pamResponse, nil
}
}
}
}
// Return original error if not MFA-related
return api.PAMAccessResponse{}, err
}

return pamResponse, nil
}
2 changes: 1 addition & 1 deletion packages/pam/local/database-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func StartDatabaseLocalProxy(accessToken string, accountPath string, projectID s
ProjectId: projectID,
}

pamResponse, err := api.CallPAMAccess(httpClient, pamRequest)
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest)
if err != nil {
var apiErr *api.APIError
if errors.As(err, &apiErr) && apiErr.ErrorMessage == "A policy is in place for this resource" {
Expand Down
2 changes: 1 addition & 1 deletion packages/pam/local/ssh-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func StartSSHLocalProxy(accessToken string, accountPath string, projectID string
ProjectId: projectID,
}

pamResponse, err := api.CallPAMAccess(httpClient, pamRequest)
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest)
if err != nil {
util.HandleError(err, "Failed to access PAM account")
return
Expand Down
50 changes: 50 additions & 0 deletions packages/util/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/Infisical/infisical-merge/packages/api"
"github.com/Infisical/infisical-merge/packages/models"
"github.com/go-resty/resty/v2"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -558,11 +559,59 @@ func GenerateETagFromSecrets(secrets []models.SingleEnvironmentVariable) string

func IsDevelopmentMode() bool {
return CLI_VERSION == "devel"

}

// HandleMFASession opens a browser for MFA verification and polls until completion
func HandleMFASession(httpClient *resty.Client, mfaSessionId string, mfaMethod string, infisicalURL string) error {
// Construct MFA URL
mfaURL := fmt.Sprintf("%s/mfa-session/%s", strings.TrimSuffix(infisicalURL, "/api"), mfaSessionId)

// Display MFA message
fmt.Printf("\n🔐 MFA Verification Required (%s)\n", mfaMethod)
fmt.Printf("→ %s\n", mfaURL)

// Try to open browser
if err := OpenBrowser(mfaURL); err != nil {
log.Debug().Err(err).Msg("Failed to open browser automatically")
} else {
fmt.Println("✓ Browser opened automatically")
}

fmt.Println("⏳ Waiting for MFA verification...\n")

// Poll for MFA completion
maxAttempts := 150 // 5 minutes at 2s intervals
pollInterval := 2 * time.Second

for i := 0; i < maxAttempts; i++ {
time.Sleep(pollInterval)

status, err := api.CallGetMFASessionStatus(httpClient, mfaSessionId)
if err != nil {
// Check if it's a 404 (session expired)
if apiErr, ok := err.(*api.APIError); ok {
if apiErr.StatusCode == 404 {
return fmt.Errorf("MFA session expired. Please try again")
}
}
// Continue polling on other errors
log.Debug().Err(err).Msg("Error polling MFA status, will retry")
continue
}

if status.Status == api.MFASessionStatusActive {
return nil
}
}

return fmt.Errorf("MFA verification timeout. Please try again")
}

// OpenBrowser attempts to open a URL in the user's default browser
func OpenBrowser(url string) error {
var cmd *exec.Cmd

switch runtime.GOOS {
case "darwin":
cmd = exec.Command("open", url)
Expand All @@ -571,5 +620,6 @@ func OpenBrowser(url string) error {
default: // linux and others
cmd = exec.Command("xdg-open", url)
}

return cmd.Start()
}