Skip to content

Commit bcfcfd0

Browse files
committed
Refactor UserRepository
1 parent 4711e93 commit bcfcfd0

File tree

8 files changed

+69
-107
lines changed

8 files changed

+69
-107
lines changed

authenticator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func (s *Authenticator) Authenticate(ctx context.Context, info AuthInfo) (AuthRe
139139
}
140140
}
141141

142-
user, er1 := s.Repository.GetUser(ctx, info)
142+
user, er1 := s.Repository.GetUser(ctx, info.Username)
143143
if er1 != nil {
144144
return result, er1
145145
}

cassandra/user_repository.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
"github.com/gocql/gocql"
1212
)
1313

14-
type AuthenticationRepository struct {
14+
type UserRepository struct {
1515
Session *gocql.Session
1616
userTableName string
1717
passwordTableName string
@@ -41,16 +41,16 @@ type AuthenticationRepository struct {
4141
TwoFactorsName string
4242
}
4343

44-
func NewAuthenticationRepositoryByConfig(session *gocql.Session, userTableName, passwordTableName string, activatedStatus string, status a.UserStatusConfig, c a.SchemaConfig, options ...func(context.Context, string) (bool, error)) *AuthenticationRepository {
45-
return NewAuthenticationRepository(session, userTableName, passwordTableName, activatedStatus, status, c.Id, c.Username, c.UserId, c.SuccessTime, c.FailTime, c.FailCount, c.LockedUntilTime, c.Status, c.PasswordChangedTime, c.Password, c.Contact, c.Email, c.Phone, c.DisplayName, c.MaxPasswordAge, c.UserType, c.AccessDateFrom, c.AccessDateTo, c.AccessTimeFrom, c.AccessTimeTo, c.TwoFactors, options...)
44+
func NewUserRepositoryByConfig(session *gocql.Session, userTableName, passwordTableName string, activatedStatus string, status a.UserStatusConfig, c a.SchemaConfig, options ...func(context.Context, string) (bool, error)) *UserRepository {
45+
return NewUserRepository(session, userTableName, passwordTableName, activatedStatus, status, c.Id, c.Username, c.UserId, c.SuccessTime, c.FailTime, c.FailCount, c.LockedUntilTime, c.Status, c.PasswordChangedTime, c.Password, c.Contact, c.Email, c.Phone, c.DisplayName, c.MaxPasswordAge, c.UserType, c.AccessDateFrom, c.AccessDateTo, c.AccessTimeFrom, c.AccessTimeTo, c.TwoFactors, options...)
4646
}
4747

48-
func NewAuthenticationRepository(session *gocql.Session, userTableName, passwordTableName string, activatedStatus string, status a.UserStatusConfig, idName, userName, userID, successTimeName, failTimeName, failCountName, lockedUntilTimeName, statusName, passwordChangedTimeName, passwordName, contactName, emailName, phoneName, displayNameName, maxPasswordAgeName, userTypeName, accessDateFromName, accessDateToName, accessTimeFromName, accessTimeToName, twoFactorsName string, options ...func(context.Context, string) (bool, error)) *AuthenticationRepository {
48+
func NewUserRepository(session *gocql.Session, userTableName, passwordTableName string, activatedStatus string, status a.UserStatusConfig, idName, userName, userID, successTimeName, failTimeName, failCountName, lockedUntilTimeName, statusName, passwordChangedTimeName, passwordName, contactName, emailName, phoneName, displayNameName, maxPasswordAgeName, userTypeName, accessDateFromName, accessDateToName, accessTimeFromName, accessTimeToName, twoFactorsName string, options ...func(context.Context, string) (bool, error)) *UserRepository {
4949
var checkTwoFactors func(context.Context, string) (bool, error)
5050
if len(options) >= 1 {
5151
checkTwoFactors = options[0]
5252
}
53-
return &AuthenticationRepository{
53+
return &UserRepository{
5454
Session: session,
5555
userTableName: strings.ToLower(userTableName),
5656
passwordTableName: strings.ToLower(passwordTableName),
@@ -81,11 +81,12 @@ func NewAuthenticationRepository(session *gocql.Session, userTableName, password
8181
}
8282
}
8383

84-
func (r *AuthenticationRepository) GetUser(ctx context.Context, user a.AuthInfo) (*a.UserInfo, error) {
84+
func (r *UserRepository) GetUser(ctx context.Context, username string) (*a.UserInfo, error) {
8585
session := r.Session
8686
userInfo := a.UserInfo{}
8787
query := "SELECT * FROM " + r.userTableName + " WHERE " + r.UserName + " = ? ALLOW FILTERING"
88-
raws := session.Query(query, user.Username).Iter()
88+
raws := session.Query(query, username).Iter()
89+
userInfo.Username = username
8990
for {
9091
// New map each iteration
9192
row := make(map[string]interface{})
@@ -206,11 +207,11 @@ func (r *AuthenticationRepository) GetUser(ctx context.Context, user a.AuthInfo)
206207
return &userInfo, nil
207208
}
208209

209-
func (r *AuthenticationRepository) Pass(ctx context.Context, userId string, deactivated *bool) error {
210+
func (r *UserRepository) Pass(ctx context.Context, userId string, deactivated *bool) error {
210211
_, err := r.passAuthenticationAndActivate(ctx, userId, deactivated)
211212
return err
212213
}
213-
func (r *AuthenticationRepository) passAuthenticationAndActivate(ctx context.Context, userId string, updateStatus *bool) (int64, error) {
214+
func (r *UserRepository) passAuthenticationAndActivate(ctx context.Context, userId string, updateStatus *bool) (int64, error) {
214215
session := r.Session
215216
if len(r.SuccessTimeName) == 0 && len(r.FailCountName) == 0 && len(r.LockedUntilTimeName) == 0 {
216217
if updateStatus != nil && !*updateStatus {
@@ -253,7 +254,7 @@ func (r *AuthenticationRepository) passAuthenticationAndActivate(ctx context.Con
253254
return k1 + k2, err1
254255
}
255256

256-
func (r *AuthenticationRepository) Fail(ctx context.Context, userId string, failCount *int, lockedUntil *time.Time) error {
257+
func (r *UserRepository) Fail(ctx context.Context, userId string, failCount *int, lockedUntil *time.Time) error {
257258
if len(r.FailTimeName) == 0 && len(r.FailCountName) == 0 && len(r.LockedUntilTimeName) == 0 {
258259
return nil
259260
}

dynamodb/user_repository.go

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"time"
1515
)
1616

17-
type AuthenticationRepository struct {
17+
type UserRepository struct {
1818
Db *dynamodb.DynamoDB
1919
UserTableName string
2020
PasswordTableName string
@@ -43,16 +43,16 @@ type AuthenticationRepository struct {
4343
TwoFactorsName string
4444
}
4545

46-
func NewAuthenticationRepositoryByConfig(dynamoDB *dynamodb.DynamoDB, userTableName, passwordTableName string, activatedStatus interface{}, status auth.UserStatusConfig, c auth.SchemaConfig, options ...func(context.Context, string) (bool, error)) *AuthenticationRepository {
47-
return NewAuthenticationRepository(dynamoDB, userTableName, passwordTableName, activatedStatus, status, c.Username, c.SuccessTime, c.FailTime, c.FailCount, c.LockedUntilTime, c.Status, c.PasswordChangedTime, c.Password, c.Contact, c.Email, c.Phone, c.DisplayName, c.MaxPasswordAge, c.Roles, c.UserType, c.AccessDateFrom, c.AccessDateTo, c.AccessTimeFrom, c.AccessTimeTo, c.TwoFactors, options...)
46+
func NewUserRepositoryByConfig(dynamoDB *dynamodb.DynamoDB, userTableName, passwordTableName string, activatedStatus interface{}, status auth.UserStatusConfig, c auth.SchemaConfig, options ...func(context.Context, string) (bool, error)) *UserRepository {
47+
return NewUserRepository(dynamoDB, userTableName, passwordTableName, activatedStatus, status, c.Username, c.SuccessTime, c.FailTime, c.FailCount, c.LockedUntilTime, c.Status, c.PasswordChangedTime, c.Password, c.Contact, c.Email, c.Phone, c.DisplayName, c.MaxPasswordAge, c.Roles, c.UserType, c.AccessDateFrom, c.AccessDateTo, c.AccessTimeFrom, c.AccessTimeTo, c.TwoFactors, options...)
4848
}
4949

50-
func NewAuthenticationRepository(dynamoDB *dynamodb.DynamoDB, userTableName, passwordTableName string, activatedStatus interface{}, status auth.UserStatusConfig, userName, successTimeName, failTimeName, failCountName, lockedUntilTimeName, statusName, passwordChangedTimeName, passwordName, contactName, emailName, phoneName, displayNameName, maxPasswordAgeName, rolesName, userTypeName, accessDateFromName, accessDateToName, accessTimeFromName, accessTimeToName, twoFactors string, options ...func(context.Context, string) (bool, error)) *AuthenticationRepository {
50+
func NewUserRepository(dynamoDB *dynamodb.DynamoDB, userTableName, passwordTableName string, activatedStatus interface{}, status auth.UserStatusConfig, userName, successTimeName, failTimeName, failCountName, lockedUntilTimeName, statusName, passwordChangedTimeName, passwordName, contactName, emailName, phoneName, displayNameName, maxPasswordAgeName, rolesName, userTypeName, accessDateFromName, accessDateToName, accessTimeFromName, accessTimeToName, twoFactors string, options ...func(context.Context, string) (bool, error)) *UserRepository {
5151
var checkTwoFactors func(context.Context, string) (bool, error)
5252
if len(options) > 0 && options[0] != nil {
5353
checkTwoFactors = options[0]
5454
}
55-
return &AuthenticationRepository{
55+
return &UserRepository{
5656
Db: dynamoDB,
5757
UserTableName: userTableName,
5858
PasswordTableName: passwordTableName,
@@ -82,7 +82,7 @@ func NewAuthenticationRepository(dynamoDB *dynamodb.DynamoDB, userTableName, pas
8282
}
8383
}
8484

85-
func (r *AuthenticationRepository) GetUserInfo(ctx context.Context, username string) (*auth.UserInfo, error) {
85+
func (r *UserRepository) GetUser(ctx context.Context, username string) (*auth.UserInfo, error) {
8686
userInfo := auth.UserInfo{}
8787
filter := expression.Equal(expression.Name("_id"), expression.Value(username))
8888
expr, _ := expression.NewBuilder().WithFilter(filter).Build()
@@ -101,6 +101,7 @@ func (r *AuthenticationRepository) GetUserInfo(ctx context.Context, username str
101101
if er1 != nil {
102102
return nil, er1
103103
}
104+
userInfo.Username = username
104105
if len(r.StatusName) > 0 {
105106
rawStatus := raw[r.StatusName]
106107
status, ok := rawStatus.(string)
@@ -115,42 +116,43 @@ func (r *AuthenticationRepository) GetUserInfo(ctx context.Context, username str
115116
}
116117
}
117118
}
118-
userInfo.Deactivated = status == r.Status.Deactivated
119+
b := status == r.Status.Deactivated
120+
userInfo.Deactivated = &b
119121
userInfo.Suspended = status == r.Status.Suspended
120122
userInfo.Disable = status == r.Status.Disable
121123
}
122124

123125
if len(r.ContactName) > 0 {
124126
if contact, ok := raw[r.ContactName].(string); ok {
125-
userInfo.Contact = contact
127+
userInfo.Contact = &contact
126128
}
127129
}
128130
if len(r.EmailName) > 0 {
129131
if email, ok := raw[r.EmailName].(string); ok {
130-
userInfo.Email = email
132+
userInfo.Email = &email
131133
}
132134
}
133135
if len(r.PhoneName) > 0 {
134136
if phone, ok := raw[r.PhoneName].(string); ok {
135-
userInfo.Phone = phone
137+
userInfo.Phone = &phone
136138
}
137139
}
138140

139141
if len(r.DisplayNameName) > 0 {
140142
if displayName, ok := raw[r.DisplayNameName].(string); ok {
141-
userInfo.DisplayName = displayName
143+
userInfo.DisplayName = &displayName
142144
}
143145
}
144146

145147
if len(r.MaxPasswordAgeName) > 0 {
146148
if maxPasswordAgeName, ok := raw[r.MaxPasswordAgeName].(int32); ok {
147-
userInfo.MaxPasswordAge = maxPasswordAgeName
149+
userInfo.MaxPasswordAge = &maxPasswordAgeName
148150
}
149151
}
150152

151153
if len(r.UserTypeName) > 0 {
152154
if userType, ok := raw[r.UserTypeName].(string); ok {
153-
userInfo.UserType = userType
155+
userInfo.UserType = &userType
154156
}
155157
}
156158

@@ -236,7 +238,7 @@ func getTime(accessTime string) *time.Time {
236238
return nil
237239
}
238240

239-
func (r *AuthenticationRepository) getPasswordInfo(ctx context.Context, user *auth.UserInfo, raw map[string]interface{}) *auth.UserInfo {
241+
func (r *UserRepository) getPasswordInfo(ctx context.Context, user *auth.UserInfo, raw map[string]interface{}) *auth.UserInfo {
240242
if len(r.PasswordName) > 0 {
241243
if pass, ok := raw[r.PasswordName].(string); ok {
242244
user.Password = pass
@@ -263,7 +265,8 @@ func (r *AuthenticationRepository) getPasswordInfo(ctx context.Context, user *au
263265

264266
if len(r.FailCountName) > 0 {
265267
if failCountName, ok := raw[r.FailCountName].(int32); ok {
266-
user.FailCount = int(failCountName)
268+
i := int(failCountName)
269+
user.FailCount = &i
267270
}
268271
}
269272

@@ -275,15 +278,16 @@ func (r *AuthenticationRepository) getPasswordInfo(ctx context.Context, user *au
275278
return user
276279
}
277280

278-
func (r *AuthenticationRepository) Pass(ctx context.Context, userId string) (int64, error) {
279-
return r.passAuthenticationAndActivate(ctx, userId, false)
281+
func (r *UserRepository) Pass(ctx context.Context, userId string) error {
282+
_, err := r.passAuthenticationAndActivate(ctx, userId, false)
283+
return err
280284
}
281285

282-
func (r *AuthenticationRepository) PassAndActivate(ctx context.Context, userId string) (int64, error) {
286+
func (r *UserRepository) PassAndActivate(ctx context.Context, userId string) (int64, error) {
283287
return r.passAuthenticationAndActivate(ctx, userId, true)
284288
}
285289

286-
func (r *AuthenticationRepository) passAuthenticationAndActivate(ctx context.Context, userId string, updateStatus bool) (int64, error) {
290+
func (r *UserRepository) passAuthenticationAndActivate(ctx context.Context, userId string, updateStatus bool) (int64, error) {
287291
if len(r.SuccessTimeName) == 0 && len(r.FailCountName) == 0 && len(r.LockedUntilTimeName) == 0 {
288292
if !updateStatus || len(r.StatusName) == 0 {
289293
return 0, nil
@@ -318,7 +322,7 @@ func (r *AuthenticationRepository) passAuthenticationAndActivate(ctx context.Con
318322
return k1 + k2, er2
319323
}
320324

321-
func (r *AuthenticationRepository) Fail(ctx context.Context, userId string, failCount int, lockedUntil *time.Time) error {
325+
func (r *UserRepository) Fail(ctx context.Context, userId string, failCount *int, lockedUntil *time.Time) error {
322326
if len(r.FailTimeName) == 0 && len(r.FailCountName) == 0 && len(r.LockedUntilTimeName) == 0 {
323327
return nil
324328
}
@@ -327,8 +331,8 @@ func (r *AuthenticationRepository) Fail(ctx context.Context, userId string, fail
327331
if len(r.FailTimeName) > 0 {
328332
pass[r.FailTimeName] = time.Now()
329333
}
330-
if len(r.FailCountName) > 0 {
331-
pass[r.FailCountName] = failCount
334+
if len(r.FailCountName) > 0 && failCount != nil {
335+
pass[r.FailCountName] = *failCount + 1
332336
if len(r.LockedUntilTimeName) > 0 {
333337
pass[r.LockedUntilTimeName] = lockedUntil
334338
}

0 commit comments

Comments
 (0)