Skip to content

Commit 9154105

Browse files
Merge pull request #61 from Infisical/feat/add-pam-session-gate-handling
feat: add mfa handling for pam access account
2 parents b39e859 + 25c4b80 commit 9154105

File tree

7 files changed

+144
-16
lines changed

7 files changed

+144
-16
lines changed

packages/api/api.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ const (
5353
operationCallGetPamSessionKey = "CallGetPamSessionKey"
5454
operationCallUploadPamSessionLog = "CallUploadPamSessionLog"
5555
operationCallPAMSessionTermination = "CallPAMSessionTermination"
56+
operationCallGetMFASessionStatus = "CallGetMFASessionStatus"
5657
operationCallOrgRelayHeartBeat = "CallOrgRelayHeartBeat"
5758
operationCallInstanceRelayHeartBeat = "CallInstanceRelayHeartBeat"
5859
operationCallIssueCertificate = "CallIssueCertificate"
@@ -1000,6 +1001,25 @@ func CallPAMSessionTermination(httpClient *resty.Client, sessionId string) error
10001001
return nil
10011002
}
10021003

1004+
func CallGetMFASessionStatus(httpClient *resty.Client, mfaSessionId string) (MFASessionStatusResponse, error) {
1005+
var mfaSessionStatusResponse MFASessionStatusResponse
1006+
response, err := httpClient.
1007+
R().
1008+
SetResult(&mfaSessionStatusResponse).
1009+
SetHeader("User-Agent", USER_AGENT).
1010+
Get(fmt.Sprintf("%v/v2/mfa-sessions/%s/status", config.INFISICAL_URL, mfaSessionId))
1011+
1012+
if err != nil {
1013+
return MFASessionStatusResponse{}, NewGenericRequestError(operationCallGetMFASessionStatus, err)
1014+
}
1015+
1016+
if response.IsError() {
1017+
return MFASessionStatusResponse{}, NewAPIErrorWithResponse(operationCallGetMFASessionStatus, response, nil)
1018+
}
1019+
1020+
return mfaSessionStatusResponse, nil
1021+
}
1022+
10031023
func CallIssueCertificate(httpClient *resty.Client, request IssueCertificateRequest) (*CertificateResponse, error) {
10041024
var resBody CertificateResponse
10051025
response, err := httpClient.

packages/api/errors.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func NewGenericRequestError(operation string, err error) *GenericRequestError {
2424

2525
// APIError represents an error response from the API
2626
type APIError struct {
27+
Name string `json:"name"`
2728
AdditionalContext string `json:"additionalContext,omitempty"`
2829
ExtraMessages []string `json:"-"`
2930
Details any `json:"details,omitempty"`
@@ -79,7 +80,7 @@ func (e APIError) Error() string {
7980
}
8081

8182
func NewAPIErrorWithResponse(operation string, res *resty.Response, additionalContext *string) error {
82-
errorMessage, details := TryParseErrorBody(res)
83+
errorMessage, details, errorName := TryParseErrorBody(res)
8384
reqId := util.TryExtractReqId(res)
8485

8586
if res == nil {
@@ -88,6 +89,7 @@ func NewAPIErrorWithResponse(operation string, res *resty.Response, additionalCo
8889

8990
apiError := &APIError{
9091
Operation: operation,
92+
Name: errorName,
9193
Method: res.Request.Method,
9294
URL: res.Request.URL,
9395
StatusCode: res.StatusCode(),
@@ -134,26 +136,27 @@ type errorResponse struct {
134136
Message string `json:"message"`
135137
Details any `json:"details"`
136138
ReqId string `json:"reqId"`
139+
Name string `json:"error"`
137140
}
138141

139142
/*
140143
Instead of changing the signature of the sdk function - let's just keep a one local to this codebase
141144
*/
142-
func TryParseErrorBody(res *resty.Response) (string, any) {
145+
func TryParseErrorBody(res *resty.Response) (string, any, string) {
143146
var details any
144147

145148
if res == nil || !res.IsError() {
146-
return "", details
149+
return "", details, ""
147150
}
148151

149152
body := res.String()
150153
if body == "" {
151-
return "", details
154+
return "", details, ""
152155
}
153156

154157
// stringify zod body entirely
155158
if res.StatusCode() == 422 {
156-
return body, details
159+
return body, details, ""
157160
}
158161

159162
// now we have a string, we need to try to parse it as json
@@ -162,30 +165,30 @@ func TryParseErrorBody(res *resty.Response) (string, any) {
162165
err := json.Unmarshal([]byte(body), &errorResponse)
163166

164167
if err != nil {
165-
return "", details
168+
return "", details, ""
166169
}
167170

168171
// Check if details is empty and return nil if so
169172
if errorResponse.Details != nil {
170173
switch v := errorResponse.Details.(type) {
171174
case []any:
172175
if len(v) == 0 {
173-
return errorResponse.Message, nil
176+
return errorResponse.Message, nil, errorResponse.Name
174177
}
175178
case []string:
176179
if len(v) == 0 {
177-
return errorResponse.Message, nil
180+
return errorResponse.Message, nil, errorResponse.Name
178181
}
179182
case map[string]any:
180183
if len(v) == 0 {
181-
return errorResponse.Message, nil
184+
return errorResponse.Message, nil, errorResponse.Name
182185
}
183186
case string:
184187
if v == "" {
185-
return errorResponse.Message, nil
188+
return errorResponse.Message, nil, errorResponse.Name
186189
}
187190
}
188191
}
189192

190-
return errorResponse.Message, errorResponse.Details
193+
return errorResponse.Message, errorResponse.Details, errorResponse.Name
191194
}

packages/api/model.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -787,9 +787,10 @@ type RegisterGatewayResponse struct {
787787
}
788788

789789
type PAMAccessRequest struct {
790-
Duration string `json:"duration,omitempty"`
791-
AccountPath string `json:"accountPath,omitempty"`
792-
ProjectId string `json:"projectId,omitempty"`
790+
Duration string `json:"duration,omitempty"`
791+
AccountPath string `json:"accountPath,omitempty"`
792+
ProjectId string `json:"projectId,omitempty"`
793+
MfaSessionId string `json:"mfaSessionId,omitempty"`
793794
}
794795

795796
type PAMAccessResponse struct {
@@ -843,6 +844,18 @@ type PAMSessionCredentials struct {
843844
ServiceAccountToken string `json:"serviceAccountToken,omitempty"`
844845
}
845846

847+
type MFASessionStatus string
848+
849+
const (
850+
MFASessionStatusPending MFASessionStatus = "PENDING"
851+
MFASessionStatusActive MFASessionStatus = "ACTIVE"
852+
)
853+
854+
type MFASessionStatusResponse struct {
855+
Status MFASessionStatus `json:"status"`
856+
MfaMethod string `json:"mfaMethod"`
857+
}
858+
846859
type UploadSessionLogEntry struct {
847860
Timestamp time.Time `json:"timestamp"`
848861
Input string `json:"input"`

packages/pam/local/base-proxy.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"time"
1616

1717
"github.com/Infisical/infisical-merge/packages/api"
18+
"github.com/Infisical/infisical-merge/packages/config"
1819
"github.com/Infisical/infisical-merge/packages/pam"
1920
"github.com/Infisical/infisical-merge/packages/util"
2021
"github.com/go-resty/resty/v2"
@@ -245,3 +246,44 @@ func (b *BaseProxyServer) WaitForConnectionsWithTimeout(timeout time.Duration) {
245246
log.Warn().Msg("Timeout waiting for connections to close, forcing shutdown")
246247
}
247248
}
249+
250+
// CallPAMAccessWithMFA attempts to access a PAM account and handles MFA if required
251+
// This is a shared function used by both database and SSH proxies
252+
func CallPAMAccessWithMFA(httpClient *resty.Client, pamRequest api.PAMAccessRequest) (api.PAMAccessResponse, error) {
253+
// Initial request
254+
pamResponse, err := api.CallPAMAccess(httpClient, pamRequest)
255+
if err != nil {
256+
// Check if MFA is required
257+
if apiErr, ok := err.(*api.APIError); ok {
258+
if apiErr.Name == "SESSION_MFA_REQUIRED" {
259+
// Extract MFA details from error
260+
if details, ok := apiErr.Details.(map[string]interface{}); ok {
261+
mfaSessionId, _ := details["mfaSessionId"].(string)
262+
mfaMethod, _ := details["mfaMethod"].(string)
263+
264+
if mfaSessionId != "" {
265+
// Handle MFA flow
266+
err := util.HandleMFASession(httpClient, mfaSessionId, mfaMethod, config.INFISICAL_URL)
267+
if err != nil {
268+
return api.PAMAccessResponse{}, fmt.Errorf("MFA verification failed: %w", err)
269+
}
270+
271+
// Retry request with MFA session ID
272+
log.Debug().Msg("Retrying PAM access with MFA session...")
273+
pamRequest.MfaSessionId = mfaSessionId
274+
pamResponse, err = api.CallPAMAccess(httpClient, pamRequest)
275+
if err != nil {
276+
return api.PAMAccessResponse{}, fmt.Errorf("failed to access PAM account after MFA: %w", err)
277+
}
278+
279+
return pamResponse, nil
280+
}
281+
}
282+
}
283+
}
284+
// Return original error if not MFA-related
285+
return api.PAMAccessResponse{}, err
286+
}
287+
288+
return pamResponse, nil
289+
}

packages/pam/local/database-proxy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func StartDatabaseLocalProxy(accessToken string, accountPath string, projectID s
6161
ProjectId: projectID,
6262
}
6363

64-
pamResponse, err := api.CallPAMAccess(httpClient, pamRequest)
64+
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest)
6565
if err != nil {
6666
var apiErr *api.APIError
6767
if errors.As(err, &apiErr) && apiErr.ErrorMessage == "A policy is in place for this resource" {

packages/pam/local/ssh-proxy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func StartSSHLocalProxy(accessToken string, accountPath string, projectID string
3838
ProjectId: projectID,
3939
}
4040

41-
pamResponse, err := api.CallPAMAccess(httpClient, pamRequest)
41+
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest)
4242
if err != nil {
4343
util.HandleError(err, "Failed to access PAM account")
4444
return

packages/util/helper.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"github.com/Infisical/infisical-merge/packages/api"
2323
"github.com/Infisical/infisical-merge/packages/models"
24+
"github.com/go-resty/resty/v2"
2425
"github.com/rs/zerolog/log"
2526
"github.com/spf13/cobra"
2627
)
@@ -558,11 +559,59 @@ func GenerateETagFromSecrets(secrets []models.SingleEnvironmentVariable) string
558559

559560
func IsDevelopmentMode() bool {
560561
return CLI_VERSION == "devel"
562+
563+
}
564+
565+
// HandleMFASession opens a browser for MFA verification and polls until completion
566+
func HandleMFASession(httpClient *resty.Client, mfaSessionId string, mfaMethod string, infisicalURL string) error {
567+
// Construct MFA URL
568+
mfaURL := fmt.Sprintf("%s/mfa-session/%s", strings.TrimSuffix(infisicalURL, "/api"), mfaSessionId)
569+
570+
// Display MFA message
571+
fmt.Printf("\n🔐 MFA Verification Required (%s)\n", mfaMethod)
572+
fmt.Printf("→ %s\n", mfaURL)
573+
574+
// Try to open browser
575+
if err := OpenBrowser(mfaURL); err != nil {
576+
log.Debug().Err(err).Msg("Failed to open browser automatically")
577+
} else {
578+
fmt.Println("✓ Browser opened automatically")
579+
}
580+
581+
fmt.Println("⏳ Waiting for MFA verification...\n")
582+
583+
// Poll for MFA completion
584+
maxAttempts := 150 // 5 minutes at 2s intervals
585+
pollInterval := 2 * time.Second
586+
587+
for i := 0; i < maxAttempts; i++ {
588+
time.Sleep(pollInterval)
589+
590+
status, err := api.CallGetMFASessionStatus(httpClient, mfaSessionId)
591+
if err != nil {
592+
// Check if it's a 404 (session expired)
593+
if apiErr, ok := err.(*api.APIError); ok {
594+
if apiErr.StatusCode == 404 {
595+
return fmt.Errorf("MFA session expired. Please try again")
596+
}
597+
}
598+
// Continue polling on other errors
599+
log.Debug().Err(err).Msg("Error polling MFA status, will retry")
600+
continue
601+
}
602+
603+
if status.Status == api.MFASessionStatusActive {
604+
return nil
605+
}
606+
}
607+
608+
return fmt.Errorf("MFA verification timeout. Please try again")
561609
}
562610

563611
// OpenBrowser attempts to open a URL in the user's default browser
564612
func OpenBrowser(url string) error {
565613
var cmd *exec.Cmd
614+
566615
switch runtime.GOOS {
567616
case "darwin":
568617
cmd = exec.Command("open", url)
@@ -571,5 +620,6 @@ func OpenBrowser(url string) error {
571620
default: // linux and others
572621
cmd = exec.Command("xdg-open", url)
573622
}
623+
574624
return cmd.Start()
575625
}

0 commit comments

Comments
 (0)