@@ -10,7 +10,7 @@ import (
1010 "time"
1111)
1212
13- type ICacheService interface {
13+ type CachePort interface {
1414 Get (ctx context.Context , key string ) (string , error )
1515 Remove (ctx context.Context , key string ) (bool , error )
1616 Expire (ctx context.Context , key string , timeToLive time.Duration ) (bool , error )
@@ -28,14 +28,14 @@ type SessionAuthorizer struct {
2828 DecodeSessionID func (value string ) (string , error )
2929 EncodeSessionID func (sid string ) string
3030 VerifyToken func (tokenString string , secret string ) (map [string ]interface {}, int64 , int64 , error )
31- Cache ICacheService
31+ Cache CachePort
3232 sessionExpiredTime time.Duration
3333 LogError func (ctx context.Context , msg string , opts ... map [string ]interface {})
3434}
3535
3636func NewSessionAuthorizer (secretKey string , verifyToken func (tokenString string , secret string ) (map [string ]interface {}, int64 , int64 , error ),
3737 refreshExpire func (w http.ResponseWriter , sessionId string ) error ,
38- cache ICacheService , sessionExpiredTime time.Duration , logError func (ctx context.Context , msg string , opts ... map [string ]interface {}), singleSession bool ,
38+ cache CachePort , sessionExpiredTime time.Duration , logError func (ctx context.Context , msg string , opts ... map [string ]interface {}), singleSession bool ,
3939 encodeSessionID func (sid string ) string ,
4040 decodeSessionID func (value string ) (string , error ),
4141 opts ... string ) * SessionAuthorizer {
@@ -139,6 +139,9 @@ func (h *SessionAuthorizer) Authorize(next http.Handler, skipRefreshTTL bool) ht
139139 return
140140 }
141141 ip := getForwardedRemoteIp (r )
142+ if len (ip ) == 0 {
143+ ip = getRemoteIp (r )
144+ }
142145 sid , ok := uData [h .SId ]
143146 if ! ok || sid != sessionId ||
144147 getValue (uData , "userAgent" ) != r .UserAgent () ||
@@ -180,6 +183,9 @@ func (h *SessionAuthorizer) Verify(next http.Handler, skipRefreshTTL bool, sessi
180183 return
181184 }
182185 ip := getForwardedRemoteIp (r )
186+ if len (ip ) == 0 {
187+ ip = getRemoteIp (r )
188+ }
183189 ctx = context .WithValue (ctx , "ip" , ip )
184190 for k , e := range payload {
185191 if len (k ) > 0 {
@@ -232,7 +238,13 @@ func getForwardedRemoteIp(r *http.Request) string {
232238 }
233239 return ""
234240}
235-
241+ func getRemoteIp (r * http.Request ) string {
242+ remoteIP , _ , err := net .SplitHostPort (r .RemoteAddr )
243+ if err != nil {
244+ remoteIP = r .RemoteAddr
245+ }
246+ return remoteIP
247+ }
236248func getValue (data map [string ]interface {}, key string ) string {
237249 if value , ok := data [key ]; ok {
238250 return value .(string )
0 commit comments