Skip to content

Commit 961f227

Browse files
committed
fix: tests
1 parent aaf0831 commit 961f227

File tree

9 files changed

+94
-27
lines changed

9 files changed

+94
-27
lines changed

server/db/providers/cassandradb/provider.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,44 @@ func NewProvider() (*provider, error) {
5858
return nil, err
5959
}
6060

61-
userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nick_name text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id, email))", KeySpace, models.Collections.User)
61+
userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User)
6262
err = session.Query(userCollectionQuery).Exec()
6363
if err != nil {
6464
log.Println("Unable to create user collection:", err)
6565
return nil, err
6666
}
67+
userIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_user_email ON %s.%s (email)", KeySpace, models.Collections.User)
68+
err = session.Query(userIndexQuery).Exec()
69+
if err != nil {
70+
log.Println("Unable to create user index:", err)
71+
return nil, err
72+
}
6773

6874
// token is reserved keyword in cassandra, hence we need to use jwt_token
69-
verificationRequestCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, jwt_token text, identifier text, expires_at bigint, email text, nonce text, redirect_uri text, created_at bigint, updated_at bigint, PRIMARY KEY (id, identifier, email))", KeySpace, models.Collections.VerificationRequest)
75+
verificationRequestCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, jwt_token text, identifier text, expires_at bigint, email text, nonce text, redirect_uri text, created_at bigint, updated_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.VerificationRequest)
7076
err = session.Query(verificationRequestCollectionQuery).Exec()
7177
if err != nil {
7278
log.Println("Unable to create verification request collection:", err)
7379
return nil, err
7480
}
81+
verificationRequestIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_email ON %s.%s (email)", KeySpace, models.Collections.VerificationRequest)
82+
err = session.Query(verificationRequestIndexQuery).Exec()
83+
if err != nil {
84+
log.Println("Unable to create verification_requests index:", err)
85+
return nil, err
86+
}
87+
verificationRequestIndexQuery = fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_identifier ON %s.%s (identifier)", KeySpace, models.Collections.VerificationRequest)
88+
err = session.Query(verificationRequestIndexQuery).Exec()
89+
if err != nil {
90+
log.Println("Unable to create verification_requests index:", err)
91+
return nil, err
92+
}
93+
verificationRequestIndexQuery = fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_jwt_token ON %s.%s (jwt_token)", KeySpace, models.Collections.VerificationRequest)
94+
err = session.Query(verificationRequestIndexQuery).Exec()
95+
if err != nil {
96+
log.Println("Unable to create verification_requests index:", err)
97+
return nil, err
98+
}
7599

76100
return &provider{
77101
db: session,

server/db/providers/cassandradb/user.go

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,39 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
3232
if err != nil {
3333
return user, err
3434
}
35+
36+
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
37+
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
38+
decoder.UseNumber()
3539
userMap := map[string]interface{}{}
36-
json.Unmarshal(bytes, &userMap)
40+
err = decoder.Decode(&userMap)
41+
if err != nil {
42+
return user, err
43+
}
3744

3845
fields := "("
3946
values := "("
4047
for key, value := range userMap {
4148
if value != nil {
42-
fields += key + ","
49+
if key == "_id" {
50+
fields += "id,"
51+
} else {
52+
fields += key + ","
53+
}
4354

4455
valueType := reflect.TypeOf(value)
45-
if valueType.Kind() == reflect.String {
46-
values += "'" + value.(string) + "',"
56+
if valueType.Name() == "string" {
57+
values += fmt.Sprintf("'%s',", value.(string))
4758
} else {
48-
values += fmt.Sprintf("%v", value) + ","
59+
values += fmt.Sprintf("%v,", value)
4960
}
5061
}
5162
}
5263

5364
fields = fields[:len(fields)-1] + ")"
5465
values = values[:len(values)-1] + ")"
5566

56-
query := fmt.Sprintf("INSERT INTO %s %s VALUES %s", KeySpace+"."+models.Collections.User, fields, values)
67+
query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.User, fields, values)
5768

5869
err = p.db.Query(query).Exec()
5970
if err != nil {
@@ -66,26 +77,46 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
6677
// UpdateUser to update user information in database
6778
func (p *provider) UpdateUser(user models.User) (models.User, error) {
6879
user.UpdatedAt = time.Now().Unix()
80+
6981
bytes, err := json.Marshal(user)
7082
if err != nil {
7183
return user, err
7284
}
85+
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
86+
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
87+
decoder.UseNumber()
7388
userMap := map[string]interface{}{}
74-
json.Unmarshal(bytes, &userMap)
89+
err = decoder.Decode(&userMap)
90+
if err != nil {
91+
return user, err
92+
}
7593

7694
updateFields := ""
7795
for key, value := range userMap {
78-
if value != nil {
79-
valueType := reflect.TypeOf(value)
80-
if valueType.Kind() == reflect.String {
81-
updateFields += key + " = '" + value.(string) + "',"
82-
} else {
83-
updateFields += key + " = " + fmt.Sprintf("%v", value) + ","
84-
}
96+
if value != nil && key != "_id" {
97+
}
98+
99+
if key == "_id" {
100+
continue
101+
}
102+
103+
if value == nil {
104+
updateFields += fmt.Sprintf("%s = null,", key)
105+
continue
106+
}
107+
108+
valueType := reflect.TypeOf(value)
109+
if valueType.Name() == "string" {
110+
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
111+
} else {
112+
updateFields += fmt.Sprintf("%s = %v, ", key, value)
85113
}
86114
}
115+
updateFields = strings.Trim(updateFields, " ")
116+
updateFields = strings.TrimSuffix(updateFields, ",")
87117

88118
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID)
119+
89120
err = p.db.Query(query).Exec()
90121
if err != nil {
91122
return user, err
@@ -97,8 +128,8 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
97128
// DeleteUser to delete user information from database
98129
func (p *provider) DeleteUser(user models.User) error {
99130
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID)
100-
101-
return p.db.Query(query).Exec()
131+
err := p.db.Query(query).Exec()
132+
return err
102133
}
103134

104135
// ListUsers to get list of users from database
@@ -114,7 +145,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
114145
// there is no offset in cassandra
115146
// so we fetch till limit + offset
116147
// and return the results from offset to limit
117-
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s ORDER BY created_at DESC LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset)
148+
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset)
118149

119150
scanner := p.db.Query(query).Iter().Scanner()
120151
counter := int64(0)

server/db/providers/cassandradb/verification_requests.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cassandradb
22

33
import (
44
"fmt"
5+
"log"
56
"time"
67

78
"github.com/authorizerdev/authorizer/server/db/models"
@@ -31,6 +32,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
3132
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) {
3233
var verificationRequest models.VerificationRequest
3334
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token)
35+
3436
err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt)
3537
if err != nil {
3638
return verificationRequest, err
@@ -41,7 +43,8 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
4143
// GetVerificationRequestByEmail to get verification request by email from database
4244
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) {
4345
var verificationRequest models.VerificationRequest
44-
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier)
46+
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier)
47+
4548
err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt)
4649
if err != nil {
4750
return verificationRequest, err
@@ -58,20 +61,23 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
5861
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.VerificationRequest)
5962
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
6063
if err != nil {
64+
log.Println("Error while quering verification request", err)
6165
return nil, err
6266
}
6367

6468
// there is no offset in cassandra
6569
// so we fetch till limit + offset
6670
// and return the results from offset to limit
67-
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s ORDER BY created_at DESC LIMIT %d`, KeySpace+"."+models.Collections.VerificationRequest, pagination.Limit+pagination.Offset)
71+
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s LIMIT %d`, KeySpace+"."+models.Collections.VerificationRequest, pagination.Limit+pagination.Offset)
72+
6873
scanner := p.db.Query(query).Iter().Scanner()
6974
counter := int64(0)
7075
for scanner.Next() {
7176
if counter >= pagination.Offset {
7277
var verificationRequest models.VerificationRequest
7378
err := scanner.Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt)
7479
if err != nil {
80+
log.Println("Error while parsing verification request", err)
7581
return nil, err
7682
}
7783
verificationRequests = append(verificationRequests, verificationRequest.AsAPIVerificationRequest())

server/email/invite_email.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func InviteEmail(toEmail, token, verificationURL, redirectURI string) error {
107107

108108
err := SendMail(Receiver, Subject, message)
109109
if err != nil {
110-
log.Println("=> error sending email:", err)
110+
log.Println("error sending email:", err)
111111
}
112112
return err
113113
}

server/email/verification_email.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func SendVerificationMail(toEmail, token, hostname string) error {
107107

108108
err := SendMail(Receiver, Subject, message)
109109
if err != nil {
110-
log.Println("=> error sending email:", err)
110+
log.Println("error sending email:", err)
111111
}
112112
return err
113113
}

server/handlers/oauth_callback.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ func processGithubUserInfo(code string) (models.User, error) {
259259
GivenName: &firstName,
260260
FamilyName: &lastName,
261261
Picture: &picture,
262-
Email: userRawData["sub"],
262+
Email: userRawData["email"],
263263
}
264264

265265
return user, nil

server/test/enable_access_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ import (
1515

1616
func enableAccessTest(t *testing.T, s TestSetup) {
1717
t.Helper()
18-
t.Run(`should revoke access`, func(t *testing.T) {
18+
t.Run(`should enable access`, func(t *testing.T) {
1919
req, ctx := createContext(s)
20-
email := "revoke_access." + s.TestInfo.Email
20+
email := "enable_access." + s.TestInfo.Email
2121
_, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{
2222
Email: email,
2323
})
@@ -45,7 +45,7 @@ func enableAccessTest(t *testing.T, s TestSetup) {
4545
assert.NoError(t, err)
4646
assert.NotEmpty(t, res.Message)
4747

48-
// it should allow login with revoked access
48+
// it should allow login with enabled access
4949
res, err = resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{
5050
Email: email,
5151
})

server/test/resolvers_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ func TestResolvers(t *testing.T) {
1414
constants.DbTypeSqlite: "../../data.db",
1515
// constants.DbTypeArangodb: "http://localhost:8529",
1616
// constants.DbTypeMongodb: "mongodb://localhost:27017",
17+
// constants.DbTypeCassandraDB: "127.0.0.1:9042",
1718
}
1819

1920
for dbType, dbURL := range databases {

server/test/test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ func cleanData(email string) {
4646
err = db.Provider.DeleteVerificationRequest(verificationRequest)
4747
}
4848

49+
verificationRequest, err = db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeMagicLinkLogin)
50+
if err == nil {
51+
err = db.Provider.DeleteVerificationRequest(verificationRequest)
52+
}
53+
4954
dbUser, err := db.Provider.GetUserByEmail(email)
5055
if err == nil {
5156
db.Provider.DeleteUser(dbUser)

0 commit comments

Comments
 (0)