Skip to content

Commit ef6c303

Browse files
authored
Add reddit provider and session (#523)
* add reddit provider and session * fix: wechat failing test
1 parent 08df1f0 commit ef6c303

File tree

5 files changed

+393
-1
lines changed

5 files changed

+393
-1
lines changed

providers/reddit/reddit.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package reddit
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"github.com/markbates/goth"
7+
"golang.org/x/oauth2"
8+
"io"
9+
"net/http"
10+
"time"
11+
)
12+
13+
const (
14+
authURL = "https://www.reddit.com/api/v1/authorize"
15+
)
16+
17+
type Provider struct {
18+
providerName string
19+
duration string
20+
config oauth2.Config
21+
client http.Client
22+
// TODO: userURL should be a constant
23+
userURL string
24+
}
25+
26+
func New(clientID string, clientSecret string, redirectURI string, duration string, tokenEndpoint string, userURL string, scopes ...string) Provider {
27+
return Provider{
28+
providerName: "reddit",
29+
duration: duration,
30+
config: oauth2.Config{
31+
ClientID: clientID,
32+
ClientSecret: clientSecret,
33+
Endpoint: oauth2.Endpoint{
34+
AuthURL: authURL,
35+
TokenURL: tokenEndpoint,
36+
AuthStyle: 0,
37+
},
38+
RedirectURL: redirectURI,
39+
Scopes: scopes,
40+
},
41+
client: http.Client{},
42+
userURL: userURL,
43+
}
44+
}
45+
46+
func (p *Provider) Name() string {
47+
return p.providerName
48+
}
49+
50+
func (p *Provider) SetName(name string) {
51+
p.providerName = name
52+
}
53+
54+
func (p *Provider) UnmarshalSession(s string) (goth.Session, error) {
55+
session := &Session{}
56+
err := json.Unmarshal([]byte(s), session)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
return session, nil
62+
}
63+
64+
func (p *Provider) Debug(b bool) {}
65+
66+
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
67+
return nil, nil
68+
}
69+
70+
func (p *Provider) RefreshTokenAvailable() bool {
71+
return true
72+
}
73+
74+
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
75+
authCodeOption := oauth2.SetAuthURLParam("duration", p.duration)
76+
return &Session{AuthURL: p.config.AuthCodeURL(state, authCodeOption)}, nil
77+
}
78+
79+
type redditResponse struct {
80+
Id string `json:"id"`
81+
Name string `json:"name"`
82+
}
83+
84+
func (p *Provider) FetchUser(s goth.Session) (goth.User, error) {
85+
session := s.(*Session)
86+
request, err := http.NewRequest("GET", p.userURL, nil)
87+
if err != nil {
88+
return goth.User{}, err
89+
}
90+
91+
bearer := "Bearer " + session.AccessToken
92+
request.Header.Add("Authorization", bearer)
93+
94+
res, err := p.client.Do(request)
95+
if err != nil {
96+
return goth.User{}, err
97+
}
98+
99+
defer res.Body.Close()
100+
101+
if res.StatusCode != http.StatusOK {
102+
if res.StatusCode == http.StatusForbidden {
103+
return goth.User{}, fmt.Errorf("%s responded with a %s because you did not provide the identity scope which is required to fetch user profile", p.providerName, res.Status)
104+
}
105+
return goth.User{}, fmt.Errorf("%s responded with a %d trying to fetch user profile", p.providerName, res.StatusCode)
106+
}
107+
108+
bits, err := io.ReadAll(res.Body)
109+
if err != nil {
110+
return goth.User{}, err
111+
}
112+
113+
var r redditResponse
114+
115+
err = json.Unmarshal(bits, &r)
116+
if err != nil {
117+
return goth.User{}, err
118+
}
119+
120+
gothUser := goth.User{
121+
RawData: nil,
122+
Provider: p.Name(),
123+
Name: r.Name,
124+
UserID: r.Id,
125+
AccessToken: session.AccessToken,
126+
RefreshToken: session.RefreshToken,
127+
ExpiresAt: time.Time{},
128+
}
129+
130+
err = json.Unmarshal(bits, &gothUser.RawData)
131+
if err != nil {
132+
return goth.User{}, err
133+
}
134+
135+
return gothUser, nil
136+
}

providers/reddit/reddit_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package reddit
2+
3+
import (
4+
"encoding/json"
5+
"github.com/markbates/goth"
6+
"golang.org/x/oauth2"
7+
"net/http"
8+
"net/http/httptest"
9+
"reflect"
10+
"testing"
11+
"time"
12+
)
13+
14+
var response = redditResponse{
15+
Id: "invader21",
16+
Name: "JohnDoe",
17+
}
18+
19+
func TestProvider(t *testing.T) {
20+
t.Run("create a new provider", func(t *testing.T) {
21+
got := New("client id", "client secret", "redirect uri", "duration", "example.com", "userURL", "scope1", "scope2", "scope 3")
22+
want := Provider{
23+
providerName: "reddit",
24+
duration: "duration",
25+
config: oauth2.Config{
26+
ClientID: "client id",
27+
ClientSecret: "client secret",
28+
Endpoint: oauth2.Endpoint{
29+
AuthURL: authURL,
30+
TokenURL: "example.com",
31+
AuthStyle: 0,
32+
},
33+
RedirectURL: "redirect uri",
34+
Scopes: []string{"scope1", "scope2", "scope 3"},
35+
},
36+
userURL: "userURL",
37+
}
38+
39+
if !reflect.DeepEqual(got, want) {
40+
t.Errorf("\033[31;1;4mgot\033[0m %+v, \n\t \033[31;1;4mwant\033[0m %+v", got, want)
41+
}
42+
})
43+
44+
t.Run("fetch reddit user that created the given session", func(t *testing.T) {
45+
redditServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
46+
b, err := json.Marshal(response)
47+
if err != nil {
48+
t.Fatal(err)
49+
}
50+
writer.Header().Add("Content-Type", "application/json")
51+
writer.Write(b)
52+
}))
53+
54+
defer redditServer.Close()
55+
56+
userURL := redditServer.URL
57+
p := New("client id", "client secret", "redirect uri", "duration", "example.com", userURL, "scope1", "scope2", "scope 3")
58+
s := &Session{
59+
AuthURL: "",
60+
AccessToken: "i am a token",
61+
TokenType: "bearer",
62+
RefreshToken: "your refresh token",
63+
Expiry: time.Time{},
64+
}
65+
66+
got, err := p.FetchUser(s)
67+
if err != nil {
68+
t.Errorf("did not expect an error: %s", err)
69+
}
70+
71+
want := goth.User{
72+
RawData: map[string]interface{}{
73+
"id": "invader21",
74+
"name": "JohnDoe",
75+
},
76+
Provider: "reddit",
77+
Name: "JohnDoe",
78+
UserID: "invader21",
79+
AccessToken: "i am a token",
80+
RefreshToken: "your refresh token",
81+
ExpiresAt: time.Time{},
82+
}
83+
84+
if !reflect.DeepEqual(got, want) {
85+
t.Errorf("\033[31;1;4mgot\033[0m %+v, \n\t\t \033[31;1;4mwant\033[0m %+v", got, want)
86+
}
87+
})
88+
}

providers/reddit/session.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package reddit
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"github.com/markbates/goth"
8+
"golang.org/x/oauth2"
9+
"time"
10+
)
11+
12+
type Session struct {
13+
AuthURL string
14+
AccessToken string `json:"access_token"`
15+
TokenType string `json:"token_type,omitempty"`
16+
RefreshToken string `json:"refresh_token,omitempty"`
17+
Expiry time.Time `json:"expiry,omitempty"`
18+
}
19+
20+
func (s *Session) GetAuthURL() (string, error) {
21+
return s.AuthURL, nil
22+
}
23+
24+
func (s *Session) Marshal() string {
25+
b, _ := json.Marshal(s)
26+
return string(b)
27+
}
28+
29+
func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) {
30+
p := provider.(*Provider)
31+
t, err := p.config.Exchange(context.WithValue(context.Background(), oauth2.HTTPClient, p.client), params.Get("code"))
32+
if err != nil {
33+
return "", err
34+
}
35+
36+
if !t.Valid() {
37+
return "", errors.New("invalid token received from provider")
38+
}
39+
40+
s.AccessToken = t.AccessToken
41+
s.TokenType = t.TokenType
42+
s.RefreshToken = t.RefreshToken
43+
s.Expiry = t.Expiry
44+
45+
return s.AccessToken, nil
46+
}

providers/reddit/session_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package reddit
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/url"
8+
"testing"
9+
)
10+
11+
var validAuthResponseTestData = struct {
12+
AccessToken string `json:"access_token"`
13+
TokenType string `json:"token_type"`
14+
ExpiresIn int `json:"expires_in"`
15+
Scope string `json:"scope"`
16+
RefreshToken string `json:"refresh_token"`
17+
}{
18+
AccessToken: "i am a token",
19+
TokenType: "type",
20+
ExpiresIn: 120,
21+
Scope: "identity",
22+
RefreshToken: "your refresh token",
23+
}
24+
25+
var invalidAuthResponseTestData = struct {
26+
AccessToken string `json:"access_token"`
27+
TokenType string `json:"token_type"`
28+
ExpiresIn int `json:"expires_in"`
29+
Scope string `json:"scope"`
30+
RefreshToken string `json:"refresh_token"`
31+
}{
32+
AccessToken: "",
33+
TokenType: "type",
34+
ExpiresIn: 120,
35+
Scope: "identity",
36+
RefreshToken: "Your refresh token",
37+
}
38+
39+
func TestSession(t *testing.T) {
40+
t.Run("gets the URL for the authentication end-point for the provider", func(t *testing.T) {
41+
s := Session{AuthURL: "example.com"}
42+
got, err := s.GetAuthURL()
43+
if err != nil {
44+
t.Fatal("should return a url string")
45+
}
46+
47+
want := "example.com"
48+
49+
if got != want {
50+
t.Errorf("got %q want %q", got, want)
51+
}
52+
})
53+
54+
t.Run("generates a string representation of the session", func(t *testing.T) {
55+
s := Session{
56+
AuthURL: "example",
57+
}
58+
got := s.Marshal()
59+
want := `{"AuthURL":"example","access_token":"","expiry":"0001-01-01T00:00:00Z"}`
60+
61+
if got != want {
62+
t.Errorf("got %q want %q", got, want)
63+
}
64+
})
65+
66+
t.Run("return an access token", func(t *testing.T) {
67+
68+
s := Session{AuthURL: "example.com"}
69+
authServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
70+
b, err := json.Marshal(validAuthResponseTestData)
71+
if err != nil {
72+
writer.WriteHeader(http.StatusInternalServerError)
73+
return
74+
}
75+
writer.Header().Add("Content-Type", "application/json")
76+
writer.WriteHeader(http.StatusOK)
77+
writer.Write(b)
78+
}))
79+
80+
tokenURL := authServer.URL
81+
82+
p := New("CLIENT_ID", "CLIENT_SECRET", "URI", "DURATION", tokenURL, "SCOPE_STRING1", "SCOPE_STRING2")
83+
u := url.Values{}
84+
u.Set("code", "12345678")
85+
86+
got, err := s.Authorize(&p, u)
87+
if err != nil {
88+
t.Fatal("did not expect an error: ", err)
89+
}
90+
91+
want := validAuthResponseTestData.AccessToken
92+
93+
if got != want {
94+
t.Errorf("got %q want %q", got, want)
95+
}
96+
})
97+
98+
t.Run("validates access token", func(t *testing.T) {
99+
s := Session{AuthURL: "example.com"}
100+
authServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
101+
b, err := json.Marshal(invalidAuthResponseTestData)
102+
if err != nil {
103+
writer.WriteHeader(http.StatusInternalServerError)
104+
return
105+
}
106+
writer.Header().Add("Content-Type", "application/json")
107+
writer.WriteHeader(http.StatusOK)
108+
writer.Write(b)
109+
}))
110+
111+
tokenURL := authServer.URL
112+
113+
p := New("CLIENT_ID", "CLIENT_SECRET", "URI", "DURATION", tokenURL, "SCOPE_STRING1", "SCOPE_STRING2")
114+
u := url.Values{}
115+
u.Set("code", "12345678")
116+
117+
_, err := s.Authorize(&p, u)
118+
if err == nil {
119+
t.Errorf("expected an error but didn't get one")
120+
}
121+
})
122+
}

0 commit comments

Comments
 (0)