diff --git a/internal/db/token_repo_redis.go b/internal/db/token_repo_redis.go index 32c03ed1..9a60e5b7 100644 --- a/internal/db/token_repo_redis.go +++ b/internal/db/token_repo_redis.go @@ -42,10 +42,14 @@ func (r RedisAdapter) SetAccessToken(ctx context.Context, token models.AuthToken return r.setAuthToken(ctx, token) } -func (r RedisAdapter) SetAccessTokenExpiry(ctx context.Context, token models.AuthToken, expiresAt time.Time) error { +func (r RedisAdapter) SetAccessTokenExpiry(ctx context.Context, token models.AuthToken, expiresAtLimit time.Time) error { if token.Type != models.AccessTokenType { return fmt.Errorf("token is not of the right type") } + expiresAt := expiresAtLimit + if !token.ExpiresAt.IsZero() && token.ExpiresAt.Before(expiresAtLimit) { + expiresAt = token.ExpiresAt + } return r.setAuthTokenExpiry(ctx, token, expiresAt) } @@ -57,10 +61,14 @@ func (r RedisAdapter) SetRefreshToken(ctx context.Context, token models.AuthToke return r.setAuthToken(ctx, token) } -func (r RedisAdapter) SetRefreshTokenExpiry(ctx context.Context, token models.AuthToken, expiresAt time.Time) error { +func (r RedisAdapter) SetRefreshTokenExpiry(ctx context.Context, token models.AuthToken, expiresAtLimit time.Time) error { if token.Type != models.RefreshTokenType { return fmt.Errorf("token is not of the right type") } + expiresAt := expiresAtLimit + if !token.ExpiresAt.IsZero() && token.ExpiresAt.Before(expiresAtLimit) { + expiresAt = token.ExpiresAt + } return r.setAuthTokenExpiry(ctx, token, expiresAt) } @@ -71,10 +79,14 @@ func (r RedisAdapter) SetIDToken(ctx context.Context, token models.AuthToken) er return r.setAuthToken(ctx, token) } -func (r RedisAdapter) SetIDTokenExpiry(ctx context.Context, token models.AuthToken, expiresAt time.Time) error { +func (r RedisAdapter) SetIDTokenExpiry(ctx context.Context, token models.AuthToken, expiresAtLimit time.Time) error { if token.Type != models.IDTokenType { return fmt.Errorf("token is not of the right type") } + expiresAt := expiresAtLimit + if !token.ExpiresAt.IsZero() && token.ExpiresAt.Before(expiresAtLimit) { + expiresAt = token.ExpiresAt + } return r.setAuthTokenExpiry(ctx, token, expiresAt) } diff --git a/internal/models/token_repository.go b/internal/models/token_repository.go index 54151676..6c374208 100644 --- a/internal/models/token_repository.go +++ b/internal/models/token_repository.go @@ -24,7 +24,7 @@ type AccessTokenGetter interface { type AccessTokenSetter interface { SetAccessToken(ctx context.Context, token AuthToken) error - SetAccessTokenExpiry(ctx context.Context, token AuthToken, expiresAt time.Time) error + SetAccessTokenExpiry(ctx context.Context, token AuthToken, expiresAtLimit time.Time) error } type AccessTokenRemover interface { @@ -37,7 +37,7 @@ type RefreshTokenGetter interface { type RefreshTokenSetter interface { SetRefreshToken(ctx context.Context, token AuthToken) error - SetRefreshTokenExpiry(ctx context.Context, token AuthToken, expiresAt time.Time) error + SetRefreshTokenExpiry(ctx context.Context, token AuthToken, expiresAtLimit time.Time) error } type RefreshTokenRemover interface { @@ -50,7 +50,7 @@ type IDTokenGetter interface { type IDTokenSetter interface { SetIDToken(ctx context.Context, token AuthToken) error - SetIDTokenExpiry(ctx context.Context, token AuthToken, expiresAt time.Time) error + SetIDTokenExpiry(ctx context.Context, token AuthToken, expiresAtLimit time.Time) error } type IDTokenRemover interface { diff --git a/internal/sessions/session_maker.go b/internal/sessions/session_maker.go index 4ccddad4..20208c5c 100644 --- a/internal/sessions/session_maker.go +++ b/internal/sessions/session_maker.go @@ -32,9 +32,9 @@ func (sm *SessionMakerImpl) NewSession() (models.Session, error) { if session.IdleTTL() == time.Duration(0) { session.ExpiresAt = time.Time{} } else if session.MaxTTL() == time.Duration(0) { - session.ExpiresAt = session.CreatedAt.Add(session.MaxTTL()) - } else { session.ExpiresAt = session.CreatedAt.Add(session.IdleTTL()) + } else { + session.ExpiresAt = session.CreatedAt.Add(session.MaxTTL()) } slog.Info("NEW SESSION", "session", session) return session, nil diff --git a/internal/sessions/token_handling.go b/internal/sessions/token_handling.go index 4c9150ec..5c84acd5 100644 --- a/internal/sessions/token_handling.go +++ b/internal/sessions/token_handling.go @@ -135,7 +135,7 @@ func (sessions *SessionStore) SaveTokens(c echo.Context, session *models.Session session.TokenIDs = models.SerializableMap{} } session.TokenIDs[providerID] = tokens.AccessToken.ID - expiresAt := sessions.getTokenStorageExpiration(tokens, *session) + expiresAt := sessions.getTokenStorageExpiration(*session) err = sessions.tokenStore.SetAccessToken(c.Request().Context(), tokens.AccessToken) if err != nil { return err @@ -175,11 +175,6 @@ func (*SessionStore) idTokenKey(tokenID string) string { return IDTokenCtxKey + ":" + tokenID } -// getTokenStorageExpiration returns the max session expiration unless the provider is Renku or GitLab, in which case there is no expiration -func (*SessionStore) getTokenStorageExpiration(tokens models.AuthTokenSet, session models.Session) time.Time { - providerID := tokens.AccessToken.ProviderID - if providerID == "renku" || providerID == "gitlab" { - return time.Time{} - } +func (*SessionStore) getTokenStorageExpiration(session models.Session) time.Time { return session.CreatedAt.Add(session.MaxTTL()) }