Skip to content

Commit dd240b0

Browse files
6543zeripathtechknowlogick
authored
Detect expired entries on read (#41)
Co-authored-by: zeripath <[email protected]> Co-authored-by: Andrew Thornton <[email protected]> Co-authored-by: techknowlogick <[email protected]>
1 parent bf67e28 commit dd240b0

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

file.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,15 @@ func (p *FileProvider) Read(sid string) (_ RawStore, err error) {
133133
defer p.lock.RUnlock()
134134

135135
var f *os.File
136+
expired := true
136137
if com.IsFile(filename) {
138+
modTime, err := com.FileMTime(filename)
139+
if err != nil {
140+
return nil, err
141+
}
142+
expired = (modTime + p.maxlifetime) < time.Now().Unix()
143+
}
144+
if !expired {
137145
f, err = os.OpenFile(filename, os.O_RDONLY, 0600)
138146
} else {
139147
f, err = os.Create(filename)

memory.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ type MemProvider struct {
9696
// Init initializes memory session provider.
9797
func (p *MemProvider) Init(maxLifetime int64, _ string) error {
9898
p.lock.Lock()
99+
p.list = list.New()
100+
p.data = make(map[string]*list.Element)
99101
p.maxLifetime = maxLifetime
100102
p.lock.Unlock()
101103
return nil
@@ -120,7 +122,8 @@ func (p *MemProvider) Read(sid string) (_ RawStore, err error) {
120122
e, ok := p.data[sid]
121123
p.lock.RUnlock()
122124

123-
if ok {
125+
// Only restore if the session is still alive.
126+
if ok && (e.Value.(*MemStore).lastAccess.Unix()+p.maxLifetime) >= time.Now().Unix() {
124127
if err = p.update(sid); err != nil {
125128
return nil, err
126129
}
@@ -130,7 +133,9 @@ func (p *MemProvider) Read(sid string) (_ RawStore, err error) {
130133
// Create a new session.
131134
p.lock.Lock()
132135
defer p.lock.Unlock()
133-
136+
if ok {
137+
p.list.Remove(e)
138+
}
134139
s := NewMemStore(sid)
135140
p.data[sid] = p.list.PushBack(s)
136141
return s, nil
@@ -213,5 +218,5 @@ func (p *MemProvider) GC() {
213218
}
214219

215220
func init() {
216-
Register("memory", &MemProvider{list: list.New(), data: make(map[string]*list.Element)})
221+
Register("memory", &MemProvider{})
217222
}

mysql/mysql.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,20 @@ func (p *MysqlProvider) Init(expire int64, connStr string) (err error) {
121121

122122
// Read returns raw session store by session ID.
123123
func (p *MysqlProvider) Read(sid string) (session.RawStore, error) {
124+
now := time.Now().Unix()
124125
var data []byte
125-
err := p.c.QueryRow("SELECT data FROM session WHERE `key`=?", sid).Scan(&data)
126+
expiry := now
127+
err := p.c.QueryRow("SELECT data, expiry FROM session WHERE `key`=?", sid).Scan(&data, &expiry)
126128
if err == sql.ErrNoRows {
127129
_, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)",
128-
sid, "", time.Now().Unix())
130+
sid, "", now)
129131
}
130132
if err != nil {
131133
return nil, err
132134
}
133135

134136
var kv map[interface{}]interface{}
135-
if len(data) == 0 {
137+
if len(data) == 0 || expiry+p.expire <= now {
136138
kv = make(map[interface{}]interface{})
137139
} else {
138140
kv, err = session.DecodeGob(data)

postgres/postgres.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,20 @@ func (p *PostgresProvider) Init(maxlifetime int64, connStr string) (err error) {
122122

123123
// Read returns raw session store by session ID.
124124
func (p *PostgresProvider) Read(sid string) (session.RawStore, error) {
125+
now := time.Now().Unix()
125126
var data []byte
126-
err := p.c.QueryRow("SELECT data FROM session WHERE key=$1", sid).Scan(&data)
127+
expiry := now
128+
err := p.c.QueryRow("SELECT data, expiry FROM session WHERE key=$1", sid).Scan(&data, &expiry)
127129
if err == sql.ErrNoRows {
128130
_, err = p.c.Exec("INSERT INTO session(key,data,expiry) VALUES($1,$2,$3)",
129-
sid, "", time.Now().Unix())
131+
sid, "", now)
130132
}
131133
if err != nil {
132134
return nil, err
133135
}
134136

135137
var kv map[interface{}]interface{}
136-
if len(data) == 0 {
138+
if len(data) == 0 || expiry+p.maxlifetime <= now {
137139
kv = make(map[interface{}]interface{})
138140
} else {
139141
kv, err = session.DecodeGob(data)

0 commit comments

Comments
 (0)