Skip to content

Commit e9e5240

Browse files
authored
Fix issue with id_token unmarshalling (#538)
1 parent ce547e9 commit e9e5240

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

providers/apple/session.go

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ func (s Session) Marshal() string {
4848

4949
type IDTokenClaims struct {
5050
jwt.StandardClaims
51-
AccessTokenHash string `json:"at_hash"`
52-
AuthTime int `json:"auth_time"`
53-
Email string `json:"email"`
54-
IsPrivateEmail bool `json:"is_private_email,string"`
51+
AccessTokenHash string `json:"at_hash"`
52+
AuthTime int `json:"auth_time"`
53+
Email string `json:"email"`
54+
IsPrivateEmail BoolString `json:"is_private_email"`
5555
}
5656

5757
func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) {
@@ -123,7 +123,7 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string,
123123
s.ID = ID{
124124
Sub: idToken.Claims.(*IDTokenClaims).Subject,
125125
Email: idToken.Claims.(*IDTokenClaims).Email,
126-
IsPrivateEmail: idToken.Claims.(*IDTokenClaims).IsPrivateEmail,
126+
IsPrivateEmail: idToken.Claims.(*IDTokenClaims).IsPrivateEmail.Value(),
127127
}
128128
}
129129

@@ -133,3 +133,36 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string,
133133
func (s Session) String() string {
134134
return s.Marshal()
135135
}
136+
137+
// BoolString is a type that can be unmarshalled from a JSON field that can be either a boolean or a string.
138+
// It is used to unmarshal some fields in the Apple ID token that can be sent as either boolean or string.
139+
// See https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_rest_api/authenticating_users_with_sign_in_with_apple#3383773
140+
type BoolString struct {
141+
BoolValue bool
142+
StringValue string
143+
IsValidBool bool
144+
}
145+
146+
func (bs *BoolString) UnmarshalJSON(data []byte) error {
147+
var b bool
148+
if err := json.Unmarshal(data, &b); err == nil {
149+
bs.BoolValue = b
150+
bs.IsValidBool = true
151+
return nil
152+
}
153+
154+
var s string
155+
if err := json.Unmarshal(data, &s); err == nil {
156+
bs.StringValue = s
157+
return nil
158+
}
159+
160+
return errors.New("json field can be either boolean or string")
161+
}
162+
163+
func (bs *BoolString) Value() bool {
164+
if bs.IsValidBool {
165+
return bs.BoolValue
166+
}
167+
return bs.StringValue == "true"
168+
}

providers/apple/session_test.go

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package apple
22

33
import (
4+
"encoding/json"
45
"testing"
56

6-
"github.com/markbates/goth"
77
"github.com/stretchr/testify/assert"
8+
9+
"github.com/markbates/goth"
810
)
911

1012
func Test_Implements_Session(t *testing.T) {
@@ -45,3 +47,45 @@ func Test_String(t *testing.T) {
4547

4648
a.Equal(s.String(), s.Marshal())
4749
}
50+
51+
func TestIDTokenClaimsUnmarshal(t *testing.T) {
52+
t.Parallel()
53+
a := assert.New(t)
54+
55+
cases := []struct {
56+
name string
57+
idToken string
58+
expectedClaims IDTokenClaims
59+
}{
60+
{
61+
name: "'is_private_email' claim is a string",
62+
idToken: `{"AuthURL":"","AccessToken":"","RefreshToken":"","ExpiresAt":"0001-01-01T00:00:00Z","sub":"","email":"[email protected]","is_private_email":"true"}`,
63+
expectedClaims: IDTokenClaims{
64+
65+
IsPrivateEmail: BoolString{
66+
StringValue: "true",
67+
},
68+
},
69+
},
70+
{
71+
name: "'is_private_email' claim is a boolean",
72+
idToken: `{"AuthURL":"","AccessToken":"","RefreshToken":"","ExpiresAt":"0001-01-01T00:00:00Z","sub":"","email":"[email protected]","is_private_email":true}`,
73+
expectedClaims: IDTokenClaims{
74+
75+
IsPrivateEmail: BoolString{
76+
BoolValue: true,
77+
IsValidBool: true,
78+
},
79+
},
80+
},
81+
}
82+
83+
for _, c := range cases {
84+
t.Run(c.name, func(t *testing.T) {
85+
idTokenClaims := IDTokenClaims{}
86+
err := json.Unmarshal([]byte(c.idToken), &idTokenClaims)
87+
a.NoError(err)
88+
a.Equal(idTokenClaims, c.expectedClaims)
89+
})
90+
}
91+
}

0 commit comments

Comments
 (0)