Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ import (
"github.com/gravitational/teleport/lib/join"
"github.com/gravitational/teleport/lib/join/joinv1"
"github.com/gravitational/teleport/lib/join/legacyjoin"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/services"
Expand Down Expand Up @@ -228,6 +229,13 @@ type GRPCServer struct {
// in-flight CreateAuditStream RPCs, by sending a value in at the beginning
// of the RPC and pulling one out before returning.
createAuditStreamSemaphore chan struct{}

// createAuthenticateChallengeLimiter is a rate limiter for invocations of
// /proto.AuthService/CreateAuthenticateChallenge that don't rely on a user
// context and thus warrant additional rate limiting since they are
// unauthenticated (either through direct API connections or coming from the
// proxy on behalf of a remote unauthenticated user).
createAuthenticateChallengeLimiter *limiter.RateLimiter
}

func (g *GRPCServer) SetServingStatus(service string, servingStatus grpc_health_v1.HealthCheckResponse_ServingStatus) {
Expand Down Expand Up @@ -4826,6 +4834,21 @@ func (g *GRPCServer) CreateAuthenticateChallenge(ctx context.Context, req *authp
return nil, trace.Wrap(err)
}

if req.GetContextUser() == nil {
// requests with a user context will be checked for legitimacy by
// ServerWithRoles later, so we only need to care about adding an
// additional rate limit for the non-user-context requests here

peerInfo, ok := peer.FromContext(ctx)
if !ok {
return nil, trace.BadParameter("unable to find peer")
}

if err := g.createAuthenticateChallengeLimiter.RegisterRequestFromAddr(peerInfo.Addr, nil); err != nil {
return nil, trace.LimitExceeded("rate limit exceeded")
}
}

res, err := actx.ServerWithRoles.CreateAuthenticateChallenge(ctx, req)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -5941,6 +5964,10 @@ type GRPCServerConfig struct {
UnaryInterceptors []grpc.UnaryServerInterceptor
// StreamInterceptors is the gRPC stream interceptor chain.
StreamInterceptors []grpc.StreamServerInterceptor
// CreateAuthenticateChallengeLimiterConfig is the optional configuration
// for the limiter applied to unauthenticated calls to
// CreateAuthenticateChallenge. Used in tests.
CreateAuthenticateChallengeLimiterConfig *limiter.Config
}

// CheckAndSetDefaults checks and sets default values
Expand Down Expand Up @@ -6222,11 +6249,28 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
}
grpcv1pb.RegisterServiceConfigDiscoveryServiceServer(server, grpcClientConfigService)

limiterConfig := limiter.Config{
Rates: []limiter.Rate{{
Period: defaults.LimiterPeriod,
Average: defaults.LimiterAverage,
Burst: defaults.LimiterBurst,
}},
}
if cfg.CreateAuthenticateChallengeLimiterConfig != nil {
limiterConfig = *cfg.CreateAuthenticateChallengeLimiterConfig
}
createAuthenticateChallengeLimiter, err := limiter.NewRateLimiter(limiterConfig)
if err != nil {
return nil, trace.Wrap(err)
}

authServer := &GRPCServer{
APIConfig: cfg.APIConfig,
logger: logger,
server: server,
healthcheck: health.NewServer(),

createAuthenticateChallengeLimiter: createAuthenticateChallengeLimiter,
}

if en := os.Getenv("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT"); en != "" {
Expand Down
104 changes: 80 additions & 24 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"net/http/httptest"
"sort"
"testing"
"testing/synctest"
"time"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -4643,61 +4644,50 @@ func TestListResources(t *testing.T) {
}

func TestCustomRateLimiting(t *testing.T) {
t.Parallel()

ctx := context.Background()
tests := []struct {
name string
burst int
fn func(*authclient.Client) error
fn func(context.Context, *authclient.Client) error
}{
{
name: "RPC ChangeUserAuthentication",
fn: func(clt *authclient.Client) error {
fn: func(ctx context.Context, clt *authclient.Client) error {
_, err := clt.ChangeUserAuthentication(ctx, &proto.ChangeUserAuthenticationRequest{})
return err
},
},
{
name: "RPC CreateAuthenticateChallenge",
burst: defaults.LimiterBurst,
fn: func(clt *authclient.Client) error {
_, err := clt.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{})
return err
},
},
{
name: "RPC GetAccountRecoveryToken",
fn: func(clt *authclient.Client) error {
fn: func(ctx context.Context, clt *authclient.Client) error {
_, err := clt.GetAccountRecoveryToken(ctx, &proto.GetAccountRecoveryTokenRequest{})
return err
},
},
{
name: "RPC StartAccountRecovery",
fn: func(clt *authclient.Client) error {
fn: func(ctx context.Context, clt *authclient.Client) error {
_, err := clt.StartAccountRecovery(ctx, &proto.StartAccountRecoveryRequest{})
return err
},
},
{
name: "RPC VerifyAccountRecovery",
fn: func(clt *authclient.Client) error {
fn: func(ctx context.Context, clt *authclient.Client) error {
_, err := clt.VerifyAccountRecovery(ctx, &proto.VerifyAccountRecoveryRequest{})
return err
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

synctestCase := func(t *testing.T) {
// Create new instance per test case, to troubleshoot which test case
// specifically failed, otherwise multiple cases can fail from running
// cases in parallel.
srv := newTestTLSServer(t)
srv := newTestTLSServer(t, withBufconnListener())
defer srv.Close()
clt, err := srv.NewClient(authtest.TestNop())
require.NoError(t, err)
defer clt.Close()

var attempts int
if test.burst == 0 {
Expand All @@ -4707,13 +4697,79 @@ func TestCustomRateLimiting(t *testing.T) {
}

for range attempts {
err = test.fn(clt)
require.False(t, trace.IsLimitExceeded(err), "got err = %v, want non-IsLimitExceeded", err)
err = test.fn(t.Context(), clt)
require.NotErrorAs(t, err, new(*trace.LimitExceededError))
}

err = test.fn(clt)
require.True(t, trace.IsLimitExceeded(err), "got err = %v, want LimitExceeded", err)
err = test.fn(t.Context(), clt)
require.ErrorAs(t, err, new(*trace.LimitExceededError))
}
t.Run(test.name, func(t *testing.T) {
synctest.Test(t, synctestCase)
})
}

t.Run("unauthenticated CreateAuthenticateChallenge", func(t *testing.T) {
synctest.Test(t, synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge)
})
}

func synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge(t *testing.T) {
ctx := t.Context()

srv := newTestTLSServer(t, withBufconnListener())
defer srv.Close()
clt, err := srv.NewClient(authtest.TestNop())
require.NoError(t, err)
defer clt.Close()

for range defaults.LimiterBurst {
_, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.NotErrorAs(t, err, new(*trace.LimitExceededError))
}
_, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.ErrorAs(t, err, new(*trace.LimitExceededError))
_, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.ErrorAs(t, err, new(*trace.LimitExceededError))

time.Sleep(defaults.LimiterPeriod)

for range defaults.LimiterBurst - defaults.LimiterAverage {
_, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.NotErrorAs(t, err, new(*trace.LimitExceededError))
}
_, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.ErrorAs(t, err, new(*trace.LimitExceededError))
_, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.ErrorAs(t, err, new(*trace.LimitExceededError))

_ = func() {
// static assertion that waiting 1000 periods should bring us to LimiterBurst but no higher
const mustBeTrue = 1000*defaults.LimiterAverage > defaults.LimiterBurst
_ = map[bool]struct{}{false: struct{}{}, mustBeTrue: struct{}{}}
}
time.Sleep(1000 * defaults.LimiterPeriod)

for range defaults.LimiterBurst {
_, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.NotErrorAs(t, err, new(*trace.LimitExceededError))
}
_, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.ErrorAs(t, err, new(*trace.LimitExceededError))
_, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest))
require.ErrorAs(t, err, new(*trace.LimitExceededError))

// no time has passed, but we can do a full burst and more if we pretend to
// have a user context (the request will fail, but not due to the rate
// limiter)

for range defaults.LimiterBurst + 1 {
_, err := clt.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{
ContextUser: &proto.ContextUser{},
},
})
require.NotErrorAs(t, err, new(*trace.LimitExceededError))
}
}

Expand Down
20 changes: 6 additions & 14 deletions lib/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ type TLSServerConfig struct {
APIConfig
// LimiterConfig is limiter config
LimiterConfig limiter.Config
// CreateAuthenticateChallengeLimiterConfig is the optional configuration
// for the limiter applied to unauthenticated calls to
// CreateAuthenticateChallenge. Used in tests.
CreateAuthenticateChallengeLimiterConfig *limiter.Config
// AccessPoint is a caching access point
AccessPoint AccessCacheWithEvents
// Component is used for debugging purposes
Expand Down Expand Up @@ -234,6 +238,8 @@ func NewTLSServer(ctx context.Context, cfg TLSServerConfig) (*TLSServer, error)
APIConfig: cfg.APIConfig,
UnaryInterceptors: authMiddleware.UnaryInterceptors(),
StreamInterceptors: authMiddleware.StreamInterceptors(),

CreateAuthenticateChallengeLimiterConfig: cfg.CreateAuthenticateChallengeLimiterConfig,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -376,20 +382,6 @@ func getCustomRate(endpoint string) *limiter.RateSet {
return nil
}
return rates
// Passwordless RPCs (potential unauthenticated challenge generation).
case "/proto.AuthService/CreateAuthenticateChallenge":
const period = defaults.LimiterPeriod
const average = defaults.LimiterAverage
const burst = defaults.LimiterBurst
rates := limiter.NewRateSet()
if err := rates.Add(period, average, burst); err != nil {
logger.DebugContext(context.Background(), "Failed to define a custom rate for rpc method, using default rate",
"error", err,
"rpc_method", endpoint,
)
return nil
}
return rates
}
return nil
}
Expand Down
20 changes: 13 additions & 7 deletions lib/limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (l *Limiter) UnaryServerInterceptorWithCustomRate(customRate CustomRateFunc
}

// Limit requests per second and simultaneous connection by client IP.
clientIP, err := clientIPFromPeer(peerInfo)
clientIP, err := clientIPFromAddr(peerInfo.Addr)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -162,7 +162,7 @@ func (l *Limiter) StreamServerInterceptor(srv any, serverStream grpc.ServerStrea
return trace.AccessDenied("missing peer info")
}
// Limit requests per second and simultaneous connection by client IP.
clientIP, err := clientIPFromPeer(peerInfo)
clientIP, err := clientIPFromAddr(peerInfo.Addr)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -176,18 +176,24 @@ func (l *Limiter) StreamServerInterceptor(srv any, serverStream grpc.ServerStrea
return handler(srv, serverStream)
}

func clientIPFromPeer(peerInfo *peer.Peer) (string, error) {
clientIP, _, err := net.SplitHostPort(peerInfo.Addr.String())
if err == nil {
return clientIP, nil
func clientIPFromAddr(addr net.Addr) (string, error) {
if addr == nil {
return "", trace.BadParameter("missing client IP")
}

s := addr.String()

// bufconn peers don't include host:port, so use a stable synthetic key
// for request/connection limiting in tests.
if peerInfo.Addr.Network() == "bufconn" && peerInfo.Addr.String() == "bufconn" {
if s == "bufconn" && addr.Network() == "bufconn" {
return "bufconn", nil
}

clientIP, _, err := net.SplitHostPort(s)
if err == nil {
return clientIP, nil
}

return "", trace.BadParameter("missing client IP")
}

Expand Down
12 changes: 12 additions & 0 deletions lib/limiter/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package limiter
import (
"cmp"
"context"
"net"
"net/http"
"sync"
"time"
Expand Down Expand Up @@ -150,6 +151,17 @@ func (l *RateLimiter) RegisterRequest(token string, customRate *ratelimit.RateSe
return nil
}

// RegisterRequestFromAddr increases the number of requests coming for the given
// remote address, returning an error if the address is invalid or if there have
// been too many requests from that address recently.
func (l *RateLimiter) RegisterRequestFromAddr(addr net.Addr, customRate *ratelimit.RateSet) error {
token, err := clientIPFromAddr(addr)
if err != nil {
return trace.Wrap(err)
}
return l.RegisterRequest(token, customRate)
}

// Add rate limiter to the handle
func (l *RateLimiter) WrapHandle(h http.Handler) {
l.TokenLimiter.Wrap(h)
Expand Down
17 changes: 10 additions & 7 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2758,13 +2758,16 @@ func (process *TeleportProcess) initAuthService() error {
TLS: tlsConfig,
GetClientCertificate: connector.ClientGetCertificate,

APIConfig: *apiConf,
LimiterConfig: cfg.Auth.Limiter,
AccessPoint: authServer.Cache,
Component: teleport.Component(teleport.ComponentAuth, process.id),
ID: process.id,
Listener: mux.TLS(),
Metrics: authMetrics,
APIConfig: *apiConf,

LimiterConfig: cfg.Auth.Limiter,
CreateAuthenticateChallengeLimiterConfig: cfg.Auth.CreateAuthenticateChallengeLimiterConfig,

AccessPoint: authServer.Cache,
Component: teleport.Component(teleport.ComponentAuth, process.id),
ID: process.id,
Listener: mux.TLS(),
Metrics: authMetrics,
})
if err != nil {
return trace.Wrap(err)
Expand Down
5 changes: 5 additions & 0 deletions lib/service/servicecfg/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ type AuthConfig struct {

Limiter limiter.Config

// CreateAuthenticateChallengeLimiterConfig is the optional configuration
// for the limiter applied to unauthenticated calls to
// CreateAuthenticateChallenge. Used in tests.
CreateAuthenticateChallengeLimiterConfig *limiter.Config

// NoAudit, when set to true, disables session recording and event audit
NoAudit bool

Expand Down
Loading
Loading