Skip to content

Commit 8b2bdca

Browse files
authored
Add 1.22 mux support (#570)
1 parent 98b105c commit 8b2bdca

File tree

61 files changed

+270
-173
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+270
-173
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ node_modules/
2424
dist/
2525
generated/
2626
.vendor/
27+
vendor
2728
*.swp
2829
.vscode/launch.json
2930
.vscode/settings.json

go.sum

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
2323
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
2424
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
2525
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
26+
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
2627
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
2728
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
2829
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk=

gothic/gothic.go

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,11 @@ import (
1616
"errors"
1717
"fmt"
1818
"io"
19-
"io/ioutil"
2019
"net/http"
2120
"net/url"
2221
"os"
2322
"strings"
2423

25-
"github.com/go-chi/chi/v5"
26-
"github.com/gorilla/mux"
2724
"github.com/gorilla/sessions"
2825
"github.com/markbates/goth"
2926
)
@@ -254,60 +251,6 @@ func Logout(res http.ResponseWriter, req *http.Request) error {
254251
return nil
255252
}
256253

257-
// GetProviderName is a function used to get the name of a provider
258-
// for a given request. By default, this provider is fetched from
259-
// the URL query string. If you provide it in a different way,
260-
// assign your own function to this variable that returns the provider
261-
// name for your request.
262-
var GetProviderName = getProviderName
263-
264-
func getProviderName(req *http.Request) (string, error) {
265-
266-
// try to get it from the url param "provider"
267-
if p := req.URL.Query().Get("provider"); p != "" {
268-
return p, nil
269-
}
270-
271-
// try to get it from the url param ":provider"
272-
if p := req.URL.Query().Get(":provider"); p != "" {
273-
return p, nil
274-
}
275-
276-
// try to get it from the context's value of "provider" key
277-
if p, ok := mux.Vars(req)["provider"]; ok {
278-
return p, nil
279-
}
280-
281-
// try to get it from the go-context's value of "provider" key
282-
if p, ok := req.Context().Value("provider").(string); ok {
283-
return p, nil
284-
}
285-
286-
// try to get it from the url param "provider", when req is routed through 'chi'
287-
if p := chi.URLParam(req, "provider"); p != "" {
288-
return p, nil
289-
}
290-
291-
// try to get it from the go-context's value of providerContextKey key
292-
if p, ok := req.Context().Value(ProviderParamKey).(string); ok {
293-
return p, nil
294-
}
295-
296-
// As a fallback, loop over the used providers, if we already have a valid session for any provider (ie. user has already begun authentication with a provider), then return that provider name
297-
providers := goth.GetProviders()
298-
session, _ := Store.Get(req, SessionName)
299-
for _, provider := range providers {
300-
p := provider.Name()
301-
value := session.Values[p]
302-
if _, ok := value.(string); ok {
303-
return p, nil
304-
}
305-
}
306-
307-
// if not found then return an empty string with the corresponding error
308-
return "", errors.New("you must select a provider")
309-
}
310-
311254
// GetContextWithProvider returns a new request context containing the provider
312255
func GetContextWithProvider(req *http.Request, provider string) *http.Request {
313256
return req.WithContext(context.WithValue(req.Context(), ProviderParamKey, provider))
@@ -347,7 +290,7 @@ func getSessionValue(session *sessions.Session, key string) (string, error) {
347290
if err != nil {
348291
return "", err
349292
}
350-
s, err := ioutil.ReadAll(r)
293+
s, err := io.ReadAll(r)
351294
if err != nil {
352295
return "", err
353296
}

gothic/gothic_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"compress/gzip"
66
"fmt"
77
"html"
8-
"io/ioutil"
8+
"io"
99
"net/http"
1010
"net/http/httptest"
1111
"net/url"
@@ -283,7 +283,7 @@ func ungzipString(value string) string {
283283
if err != nil {
284284
return "err"
285285
}
286-
s, err := ioutil.ReadAll(r)
286+
s, err := io.ReadAll(r)
287287
if err != nil {
288288
return "err"
289289
}

gothic/provider.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//go:build go1.22
2+
// +build go1.22
3+
4+
package gothic
5+
6+
import (
7+
"errors"
8+
"net/http"
9+
10+
"github.com/go-chi/chi/v5"
11+
"github.com/gorilla/mux"
12+
"github.com/markbates/goth"
13+
)
14+
15+
// GetProviderName is a function used to get the name of a provider
16+
// for a given request. By default, this provider is fetched from
17+
// the URL query string. If you provide it in a different way,
18+
// assign your own function to this variable that returns the provider
19+
// name for your request.
20+
var GetProviderName = getProviderName
21+
22+
func getProviderName(req *http.Request) (string, error) {
23+
// try to get it from the url param "provider"
24+
if p := req.URL.Query().Get("provider"); p != "" {
25+
return p, nil
26+
}
27+
28+
// try to get it from the url param ":provider"
29+
if p := req.URL.Query().Get(":provider"); p != "" {
30+
return p, nil
31+
}
32+
33+
// try to get it from the context's value of "provider" key
34+
if p, ok := mux.Vars(req)["provider"]; ok {
35+
return p, nil
36+
}
37+
38+
// try to get it from the go-context's value of "provider" key
39+
if p, ok := req.Context().Value("provider").(string); ok {
40+
return p, nil
41+
}
42+
43+
// try to get it from the url param "provider", when req is routed through 'chi'
44+
if p := chi.URLParam(req, "provider"); p != "" {
45+
return p, nil
46+
}
47+
48+
// try to get it from the route param for go >= 1.22
49+
if p := req.PathValue("provider"); p != "" {
50+
return p, nil
51+
}
52+
53+
// try to get it from the go-context's value of providerContextKey key
54+
if p, ok := req.Context().Value(ProviderParamKey).(string); ok {
55+
return p, nil
56+
}
57+
58+
// As a fallback, loop over the used providers, if we already have a valid session for any provider (ie. user has already begun authentication with a provider), then return that provider name
59+
providers := goth.GetProviders()
60+
session, _ := Store.Get(req, SessionName)
61+
for _, provider := range providers {
62+
p := provider.Name()
63+
value := session.Values[p]
64+
if _, ok := value.(string); ok {
65+
return p, nil
66+
}
67+
}
68+
69+
// if not found then return an empty string with the corresponding error
70+
return "", errors.New("you must select a provider")
71+
}

gothic/provider_legacy.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//go:build !go1.22
2+
// +build !go1.22
3+
4+
package gothic
5+
6+
import (
7+
"errors"
8+
"net/http"
9+
10+
"github.com/go-chi/chi/v5"
11+
"github.com/gorilla/mux"
12+
"github.com/markbates/goth"
13+
)
14+
15+
// GetProviderName is a function used to get the name of a provider
16+
// for a given request. By default, this provider is fetched from
17+
// the URL query string. If you provide it in a different way,
18+
// assign your own function to this variable that returns the provider
19+
// name for your request.
20+
var GetProviderName = getProviderName
21+
22+
func getProviderName(req *http.Request) (string, error) {
23+
// try to get it from the url param "provider"
24+
if p := req.URL.Query().Get("provider"); p != "" {
25+
return p, nil
26+
}
27+
28+
// try to get it from the url param ":provider"
29+
if p := req.URL.Query().Get(":provider"); p != "" {
30+
return p, nil
31+
}
32+
33+
// try to get it from the context's value of "provider" key
34+
if p, ok := mux.Vars(req)["provider"]; ok {
35+
return p, nil
36+
}
37+
38+
// try to get it from the go-context's value of "provider" key
39+
if p, ok := req.Context().Value("provider").(string); ok {
40+
return p, nil
41+
}
42+
43+
// try to get it from the url param "provider", when req is routed through 'chi'
44+
if p := chi.URLParam(req, "provider"); p != "" {
45+
return p, nil
46+
}
47+
48+
// try to get it from the go-context's value of providerContextKey key
49+
if p, ok := req.Context().Value(ProviderParamKey).(string); ok {
50+
return p, nil
51+
}
52+
53+
// As a fallback, loop over the used providers, if we already have a valid session for any provider (ie. user has already begun authentication with a provider), then return that provider name
54+
providers := goth.GetProviders()
55+
session, _ := Store.Get(req, SessionName)
56+
for _, provider := range providers {
57+
p := provider.Name()
58+
value := session.Values[p]
59+
if _, ok := value.(string); ok {
60+
return p, nil
61+
}
62+
}
63+
64+
// if not found then return an empty string with the corresponding error
65+
return "", errors.New("you must select a provider")
66+
}

gothic/provider_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//go:build go1.22
2+
// +build go1.22
3+
4+
package gothic_test
5+
6+
import (
7+
"net/http"
8+
"net/http/httptest"
9+
"net/url"
10+
"testing"
11+
12+
"github.com/markbates/goth/gothic"
13+
"github.com/stretchr/testify/assert"
14+
)
15+
16+
func Test_GetAuthURL122(t *testing.T) {
17+
a := assert.New(t)
18+
19+
res := httptest.NewRecorder()
20+
req, err := http.NewRequest("GET", "/auth", nil)
21+
a.NoError(err)
22+
req.SetPathValue("provider", "faux")
23+
24+
u, err := gothic.GetAuthURL(res, req)
25+
a.NoError(err)
26+
27+
// Check that we get the correct auth URL with a state parameter
28+
parsed, err := url.Parse(u)
29+
a.NoError(err)
30+
a.Equal("http", parsed.Scheme)
31+
a.Equal("example.com", parsed.Host)
32+
q := parsed.Query()
33+
a.Contains(q, "client_id")
34+
a.Equal("code", q.Get("response_type"))
35+
a.NotZero(q, "state")
36+
37+
// Check that if we run GetAuthURL on another request, that request's
38+
// auth URL has a different state from the previous one.
39+
req2, err := http.NewRequest("GET", "/auth?provider=faux", nil)
40+
a.NoError(err)
41+
req2.SetPathValue("provider", "faux")
42+
url2, err := gothic.GetAuthURL(httptest.NewRecorder(), req2)
43+
a.NoError(err)
44+
parsed2, err := url.Parse(url2)
45+
a.NoError(err)
46+
a.NotEqual(parsed.Query().Get("state"), parsed2.Query().Get("state"))
47+
}

providers/amazon/amazon.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"encoding/json"
88
"fmt"
99
"io"
10-
"io/ioutil"
1110
"net/http"
1211
"net/url"
1312

@@ -95,7 +94,7 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
9594
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
9695
}
9796

98-
bits, err := ioutil.ReadAll(response.Body)
97+
bits, err := io.ReadAll(response.Body)
9998
if err != nil {
10099
return user, err
101100
}

providers/azureadv2/azureadv2.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"encoding/json"
55
"fmt"
66
"io"
7-
"io/ioutil"
87
"net/http"
98

109
"github.com/markbates/goth"
@@ -199,7 +198,7 @@ func userFromReader(r io.Reader, user *goth.User) error {
199198
UserPrincipalName string `json:"userPrincipalName"` // The user's principal name.
200199
}{}
201200

202-
userBytes, err := ioutil.ReadAll(r)
201+
userBytes, err := io.ReadAll(r)
203202
if err != nil {
204203
return err
205204
}

providers/battlenet/battlenet.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"bytes"
77
"encoding/json"
88
"fmt"
9-
"io/ioutil"
9+
"io"
1010
"net/http"
1111

1212
"github.com/markbates/goth"
@@ -103,7 +103,7 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
103103
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
104104
}
105105

106-
bits, err := ioutil.ReadAll(response.Body)
106+
bits, err := io.ReadAll(response.Body)
107107
if err != nil {
108108
return user, err
109109
}

0 commit comments

Comments
 (0)