Skip to content

Commit 9a686d4

Browse files
committed
Fix compilation on Windows because of lack of SIGTSTP
1 parent 0f800c3 commit 9a686d4

File tree

2 files changed

+196
-0
lines changed

2 files changed

+196
-0
lines changed

oidc.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !windows
2+
13
package vaultkv
24

35
import (

oidc_windows.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
package vaultkv
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"net"
7+
"net/http"
8+
"net/url"
9+
"os"
10+
"os/signal"
11+
"time"
12+
13+
"github.com/hashicorp/cap/util"
14+
"github.com/hashicorp/go-secure-stdlib/base62"
15+
)
16+
17+
// AuthOIDCMetadata is the metadata member set by AuthOIDC
18+
type AuthOIDCMetadata struct {
19+
AuthURL string `json:"auth_url"`
20+
}
21+
22+
// authHalts are the signals we want to interrupt our auth callback on.
23+
// SIGTSTP is omitted for Windows.
24+
var authHalts = []os.Signal{os.Interrupt, os.Kill}
25+
26+
// AuthOIDC is a shorthand for AuthOIDCMount against the default OIDC mountpoint,
27+
// 'OIDC'.
28+
func (v *Client) AuthOIDC(username, password string) (ret *AuthOutput, err error) {
29+
return v.AuthOIDCMount("OIDC")
30+
}
31+
32+
type loginResponse struct {
33+
authOutput *AuthOutput
34+
err error
35+
}
36+
37+
// AuthOIDCMount submits the given username and password to the OIDC auth endpoint
38+
// mounted at the given mountpoint, checking it against existing OIDC auth
39+
// configurations. If auth is successful, then the AuthOutput object is returned,
40+
// and this client's AuthToken is set to the returned token. Given mountpoint is
41+
// relative to /v1/auth.
42+
func (v *Client) AuthOIDCMount(mount string) (ret *AuthOutput, err error) {
43+
// handle ctrl-c while waiting for the callback
44+
sigintCh := make(chan os.Signal, 1)
45+
signal.Notify(sigintCh, authHalts...)
46+
defer signal.Stop(sigintCh)
47+
raw := &authOutputRaw{}
48+
49+
authURL, clientNonce, err := fetchAuthURL(v, mount)
50+
if err != nil {
51+
return nil, err
52+
}
53+
doneCh := make(chan loginResponse)
54+
http.HandleFunc("/oidc/callback", callbackHandler(v, mount, clientNonce, doneCh))
55+
56+
port := "8250"
57+
listenAddress := "localhost"
58+
listener, err := net.Listen("tcp", listenAddress+":"+port)
59+
if err != nil {
60+
return nil, err
61+
}
62+
defer listener.Close()
63+
64+
fmt.Fprintf(os.Stderr, "Complete the login via your OIDC provider. Launching browser to:\n\n %s\n\n\n", authURL)
65+
if err := util.OpenURL(authURL); err != nil {
66+
return nil, fmt.Errorf("failed to launch the browser , err=%w", err)
67+
}
68+
fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n")
69+
70+
// Start local server
71+
go func() {
72+
err := http.Serve(listener, nil)
73+
if err != nil && err != http.ErrServerClosed {
74+
doneCh <- loginResponse{nil, err}
75+
}
76+
}()
77+
// Wait for either the callback to finish, or a halt signal (e.g., SIGKILL, SIGINT, SIGTSTP) to be received or up to 2 minutes
78+
select {
79+
case s := <-doneCh:
80+
return s.authOutput, s.err
81+
case <-sigintCh:
82+
return nil, errors.New("Interrupted")
83+
case <-time.After(2 * time.Minute):
84+
return nil, errors.New("Timed out waiting for response from provider")
85+
}
86+
87+
ret = raw.toFinal(AuthOIDCMetadata{})
88+
v.AuthToken = ret.ClientToken
89+
return
90+
}
91+
func fetchAuthURL(v *Client, mount string) (string, string, error) {
92+
//var authURL string
93+
94+
clientNonce, err := base62.Random(20)
95+
if err != nil {
96+
return "", "", err
97+
}
98+
99+
callbackPort := "8250"
100+
callbackMethod := "http"
101+
callbackHost := "localhost"
102+
redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort)
103+
data := map[string]interface{}{
104+
// only default role is supported
105+
//"role": role,
106+
"redirect_uri": redirectURI,
107+
"client_nonce": clientNonce,
108+
}
109+
raw := &authOutputRaw{}
110+
111+
err = v.doRequest(
112+
"POST",
113+
fmt.Sprintf("auth/%s/oidc/auth_url", mount),
114+
data,
115+
&raw,
116+
)
117+
if err != nil {
118+
return "", "", err
119+
}
120+
121+
authUrl := raw.Data["auth_url"].(string)
122+
return authUrl, clientNonce, err
123+
}
124+
func callbackHandler(v *Client, mount string, clientNonce string, doneCh chan<- loginResponse) http.HandlerFunc {
125+
return func(w http.ResponseWriter, req *http.Request) {
126+
var response string
127+
var authOutput *AuthOutput
128+
var err error
129+
defer func() {
130+
w.Write([]byte(response))
131+
doneCh <- loginResponse{authOutput, err}
132+
}()
133+
134+
// TODO: consider checking for method for post for additional auth step if required
135+
raw := &authOutputRaw{}
136+
query := url.Values{}
137+
query.Add("state", req.FormValue("state"))
138+
query.Add("code", req.FormValue("code"))
139+
query.Add("id_token", req.FormValue("id_token"))
140+
query.Add("client_nonce", clientNonce)
141+
err = v.doRequest(
142+
"GET",
143+
fmt.Sprintf("auth/%s/oidc/callback", mount),
144+
query,
145+
&raw,
146+
)
147+
authOutput = &AuthOutput{}
148+
authOutput.ClientToken = raw.Auth.ClientToken
149+
150+
successHtml := `
151+
<!DOCTYPE html>
152+
<html lang="en">
153+
<head>
154+
<meta charset="UTF-8">
155+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
156+
<style>
157+
body {
158+
font-family: Arial, sans-serif;
159+
text-align: center;
160+
padding: 50px;
161+
}
162+
h1 {
163+
color: #4CAF50;
164+
}
165+
p {
166+
color: #333;
167+
}
168+
</style>
169+
</head>
170+
<body>
171+
<h1>Success!</h1>
172+
<p>Your request was successful.</p>
173+
</body>
174+
</html>`
175+
errorHtml := `
176+
<!DOCTYPE html>
177+
<html lang="en">
178+
<head>
179+
<title>500 Internal Server Error</title>
180+
</head>
181+
<body>
182+
<h1>500 Internal Server Error</h1>
183+
<p>Something went wrong on our end. We're working on fixing it, and we'll be back as soon as possible.</p>
184+
<p>In the meantime, please try again later.</p>
185+
</body>
186+
</html>`
187+
if err != nil {
188+
fmt.Println("Error calling back to vault", err.Error())
189+
response = errorHtml
190+
} else {
191+
response = successHtml
192+
}
193+
}
194+
}

0 commit comments

Comments
 (0)