From 64627159942e3b99fe5edaec9ed140eaad190e69 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 12 Mar 2026 12:39:14 +0100 Subject: [PATCH 01/11] Add a limiter.RateLimiter helper for peer addresses --- lib/limiter/limiter.go | 20 +++++++++++++------- lib/limiter/ratelimiter.go | 12 ++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/lib/limiter/limiter.go b/lib/limiter/limiter.go index b751c878b0cf1..276e462a7417b 100644 --- a/lib/limiter/limiter.go +++ b/lib/limiter/limiter.go @@ -139,7 +139,7 @@ func (l *Limiter) UnaryServerInterceptorWithCustomRate(customRate CustomRateFunc } // Limit requests per second and simultaneous connection by client IP. - clientIP, err := clientIPFromPeer(peerInfo) + clientIP, err := clientIPFromAddr(peerInfo.Addr) if err != nil { return nil, trace.Wrap(err) } @@ -162,7 +162,7 @@ func (l *Limiter) StreamServerInterceptor(srv any, serverStream grpc.ServerStrea return trace.AccessDenied("missing peer info") } // Limit requests per second and simultaneous connection by client IP. - clientIP, err := clientIPFromPeer(peerInfo) + clientIP, err := clientIPFromAddr(peerInfo.Addr) if err != nil { return trace.Wrap(err) } @@ -176,18 +176,24 @@ func (l *Limiter) StreamServerInterceptor(srv any, serverStream grpc.ServerStrea return handler(srv, serverStream) } -func clientIPFromPeer(peerInfo *peer.Peer) (string, error) { - clientIP, _, err := net.SplitHostPort(peerInfo.Addr.String()) - if err == nil { - return clientIP, nil +func clientIPFromAddr(addr net.Addr) (string, error) { + if addr == nil { + return "", trace.BadParameter("missing client IP") } + s := addr.String() + // bufconn peers don't include host:port, so use a stable synthetic key // for request/connection limiting in tests. - if peerInfo.Addr.Network() == "bufconn" && peerInfo.Addr.String() == "bufconn" { + if s == "bufconn" && addr.Network() == "bufconn" { return "bufconn", nil } + clientIP, _, err := net.SplitHostPort(s) + if err == nil { + return clientIP, nil + } + return "", trace.BadParameter("missing client IP") } diff --git a/lib/limiter/ratelimiter.go b/lib/limiter/ratelimiter.go index 22c55a9c9c60d..8bfabf707fae9 100644 --- a/lib/limiter/ratelimiter.go +++ b/lib/limiter/ratelimiter.go @@ -21,6 +21,7 @@ package limiter import ( "cmp" "context" + "net" "net/http" "sync" "time" @@ -150,6 +151,17 @@ func (l *RateLimiter) RegisterRequest(token string, customRate *ratelimit.RateSe return nil } +// RegisterRequestFromAddr increases the number of requests coming for the given +// remote address, returning an error if the address is invalid or if there have +// been too many requests from that address recently. +func (l *RateLimiter) RegisterRequestFromAddr(addr net.Addr, customRate *ratelimit.RateSet) error { + token, err := clientIPFromAddr(addr) + if err != nil { + return trace.Wrap(err) + } + return l.RegisterRequest(token, customRate) +} + // Add rate limiter to the handle func (l *RateLimiter) WrapHandle(h http.Handler) { l.TokenLimiter.Wrap(h) From 2afb5e63ccc376b657cf939c928cc50a4a63c293 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 12 Mar 2026 12:51:55 +0100 Subject: [PATCH 02/11] Add custom rate limit for CreateAuthenticateChallenge --- lib/auth/grpcserver.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index f51df221fb4d6..ae33ce8ef4642 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -148,6 +148,7 @@ import ( "github.com/gravitational/teleport/lib/join" "github.com/gravitational/teleport/lib/join/joinv1" "github.com/gravitational/teleport/lib/join/legacyjoin" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/services" @@ -228,6 +229,13 @@ type GRPCServer struct { // in-flight CreateAuditStream RPCs, by sending a value in at the beginning // of the RPC and pulling one out before returning. createAuditStreamSemaphore chan struct{} + + // createAuthenticateChallengeUnauthenticatedLimiter is a rate limiter for + // invocations of /proto.AuthService/CreateAuthenticateChallenge that don't + // rely on a user context and thus warrant additional rate limiting since + // they are unauthenticated (either through direct API connections or coming + // from the proxy on behalf of a remote unauthenticated user). + createAuthenticateChallengeUnauthenticatedLimiter *limiter.RateLimiter } func (g *GRPCServer) SetServingStatus(service string, servingStatus grpc_health_v1.HealthCheckResponse_ServingStatus) { @@ -4826,6 +4834,21 @@ func (g *GRPCServer) CreateAuthenticateChallenge(ctx context.Context, req *authp return nil, trace.Wrap(err) } + if req.GetContextUser() == nil { + // requests with a user context will be checked for legitimacy by + // ServerWithRoles later, so we only need to care about adding an + // additional rate limit for the non-user-context requests here + + peerInfo, ok := peer.FromContext(ctx) + if !ok { + return nil, trace.BadParameter("unable to find peer") + } + + if err := g.createAuthenticateChallengeUnauthenticatedLimiter.RegisterRequestFromAddr(peerInfo.Addr, nil); err != nil { + return nil, trace.Wrap(err) + } + } + res, err := actx.ServerWithRoles.CreateAuthenticateChallenge(ctx, req) if err != nil { return nil, trace.Wrap(err) @@ -6222,11 +6245,24 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { } grpcv1pb.RegisterServiceConfigDiscoveryServiceServer(server, grpcClientConfigService) + createAuthenticateChallengeUnauthenticatedLimiter, err := limiter.NewRateLimiter(limiter.Config{ + Rates: []limiter.Rate{{ + Period: defaults.LimiterPeriod, + Average: defaults.LimiterAverage, + Burst: defaults.LimiterBurst, + }}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + authServer := &GRPCServer{ APIConfig: cfg.APIConfig, logger: logger, server: server, healthcheck: health.NewServer(), + + createAuthenticateChallengeUnauthenticatedLimiter: createAuthenticateChallengeUnauthenticatedLimiter, } if en := os.Getenv("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT"); en != "" { From febf68437812a0c15c41bc35c9f9aa518063f1e7 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 12 Mar 2026 12:52:33 +0100 Subject: [PATCH 03/11] Remove endpoint-based rate limit for CreateAuthenticateChallenge --- lib/auth/middleware.go | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index be041571055ce..3722725af76ae 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -376,20 +376,6 @@ func getCustomRate(endpoint string) *limiter.RateSet { return nil } return rates - // Passwordless RPCs (potential unauthenticated challenge generation). - case "/proto.AuthService/CreateAuthenticateChallenge": - const period = defaults.LimiterPeriod - const average = defaults.LimiterAverage - const burst = defaults.LimiterBurst - rates := limiter.NewRateSet() - if err := rates.Add(period, average, burst); err != nil { - logger.DebugContext(context.Background(), "Failed to define a custom rate for rpc method, using default rate", - "error", err, - "rpc_method", endpoint, - ) - return nil - } - return rates } return nil } From 067791020bacd26a9f7b96b681ac3b69f52ae2d5 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 12 Mar 2026 13:10:27 +0100 Subject: [PATCH 04/11] Tweak existing rate limit test to use synctest --- lib/auth/grpcserver_test.go | 104 +++++++++++++++++++++++++++--------- 1 file changed, 80 insertions(+), 24 deletions(-) diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 04fce5f7f1106..e556d48a76bd5 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -31,6 +31,7 @@ import ( "net/http/httptest" "sort" "testing" + "testing/synctest" "time" "github.com/google/go-cmp/cmp" @@ -4643,61 +4644,50 @@ func TestListResources(t *testing.T) { } func TestCustomRateLimiting(t *testing.T) { - t.Parallel() - - ctx := context.Background() tests := []struct { name string burst int - fn func(*authclient.Client) error + fn func(context.Context, *authclient.Client) error }{ { name: "RPC ChangeUserAuthentication", - fn: func(clt *authclient.Client) error { + fn: func(ctx context.Context, clt *authclient.Client) error { _, err := clt.ChangeUserAuthentication(ctx, &proto.ChangeUserAuthenticationRequest{}) return err }, }, - { - name: "RPC CreateAuthenticateChallenge", - burst: defaults.LimiterBurst, - fn: func(clt *authclient.Client) error { - _, err := clt.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{}) - return err - }, - }, { name: "RPC GetAccountRecoveryToken", - fn: func(clt *authclient.Client) error { + fn: func(ctx context.Context, clt *authclient.Client) error { _, err := clt.GetAccountRecoveryToken(ctx, &proto.GetAccountRecoveryTokenRequest{}) return err }, }, { name: "RPC StartAccountRecovery", - fn: func(clt *authclient.Client) error { + fn: func(ctx context.Context, clt *authclient.Client) error { _, err := clt.StartAccountRecovery(ctx, &proto.StartAccountRecoveryRequest{}) return err }, }, { name: "RPC VerifyAccountRecovery", - fn: func(clt *authclient.Client) error { + fn: func(ctx context.Context, clt *authclient.Client) error { _, err := clt.VerifyAccountRecovery(ctx, &proto.VerifyAccountRecoveryRequest{}) return err }, }, } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - + synctestCase := func(t *testing.T) { // Create new instance per test case, to troubleshoot which test case // specifically failed, otherwise multiple cases can fail from running // cases in parallel. - srv := newTestTLSServer(t) + srv := newTestTLSServer(t, withBufconnListener()) + defer srv.Close() clt, err := srv.NewClient(authtest.TestNop()) require.NoError(t, err) + defer clt.Close() var attempts int if test.burst == 0 { @@ -4707,13 +4697,79 @@ func TestCustomRateLimiting(t *testing.T) { } for range attempts { - err = test.fn(clt) - require.False(t, trace.IsLimitExceeded(err), "got err = %v, want non-IsLimitExceeded", err) + err = test.fn(t.Context(), clt) + require.NotErrorAs(t, err, new(*trace.LimitExceededError)) } - err = test.fn(clt) - require.True(t, trace.IsLimitExceeded(err), "got err = %v, want LimitExceeded", err) + err = test.fn(t.Context(), clt) + require.ErrorAs(t, err, new(*trace.LimitExceededError)) + } + t.Run(test.name, func(t *testing.T) { + synctest.Test(t, synctestCase) + }) + } + + t.Run("unauthenticated CreateAuthenticateChallenge", func(t *testing.T) { + synctest.Test(t, synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge) + }) +} + +func synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge(t *testing.T) { + ctx := t.Context() + + srv := newTestTLSServer(t, withBufconnListener()) + defer srv.Close() + clt, err := srv.NewClient(authtest.TestNop()) + require.NoError(t, err) + defer clt.Close() + + for range defaults.LimiterBurst { + _, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.NotErrorAs(t, err, new(*trace.LimitExceededError)) + } + _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.ErrorAs(t, err, new(*trace.LimitExceededError)) + _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.ErrorAs(t, err, new(*trace.LimitExceededError)) + + time.Sleep(defaults.LimiterPeriod) + + for range defaults.LimiterBurst - defaults.LimiterAverage { + _, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.NotErrorAs(t, err, new(*trace.LimitExceededError)) + } + _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.ErrorAs(t, err, new(*trace.LimitExceededError)) + _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.ErrorAs(t, err, new(*trace.LimitExceededError)) + + _ = func() { + // static assertion that waiting 1000 periods should bring us to LimiterBurst but no higher + const mustBeTrue = 1000*defaults.LimiterAverage > defaults.LimiterBurst + _ = map[bool]struct{}{false: struct{}{}, mustBeTrue: struct{}{}} + } + time.Sleep(1000 * defaults.LimiterPeriod) + + for range defaults.LimiterBurst { + _, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.NotErrorAs(t, err, new(*trace.LimitExceededError)) + } + _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.ErrorAs(t, err, new(*trace.LimitExceededError)) + _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + require.ErrorAs(t, err, new(*trace.LimitExceededError)) + + // no time has passed, but we can do a full burst and more if we pretend to + // have a user context (the request will fail, but not due to the rate + // limiter) + + for range defaults.LimiterBurst + 1 { + _, err := clt.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ + Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{ + ContextUser: &proto.ContextUser{}, + }, }) + require.NotErrorAs(t, err, new(*trace.LimitExceededError)) } } From 7cc10e7e83342ebde10a3f3e82dd7f82b60b02a3 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 12 Mar 2026 16:39:25 +0100 Subject: [PATCH 05/11] Hide rate limit details from error response --- lib/auth/grpcserver.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index ae33ce8ef4642..ec826a4c361e0 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -4845,7 +4845,7 @@ func (g *GRPCServer) CreateAuthenticateChallenge(ctx context.Context, req *authp } if err := g.createAuthenticateChallengeUnauthenticatedLimiter.RegisterRequestFromAddr(peerInfo.Addr, nil); err != nil { - return nil, trace.Wrap(err) + return nil, trace.LimitExceeded("rate limit exceeded") } } From bad17b4cfd0b06da4945a71217ad8ddcda7a204a Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 13 Mar 2026 00:16:02 +0100 Subject: [PATCH 06/11] Make the limiter configurable for TestAdminActionMFA --- lib/auth/grpcserver.go | 28 +++++++++++++++++---------- lib/auth/middleware.go | 6 ++++++ lib/service/service.go | 17 +++++++++------- lib/service/servicecfg/auth.go | 5 +++++ tool/tctl/common/admin_action_test.go | 2 ++ 5 files changed, 41 insertions(+), 17 deletions(-) diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index ec826a4c361e0..9e9dcaca72be3 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -230,12 +230,12 @@ type GRPCServer struct { // of the RPC and pulling one out before returning. createAuditStreamSemaphore chan struct{} - // createAuthenticateChallengeUnauthenticatedLimiter is a rate limiter for - // invocations of /proto.AuthService/CreateAuthenticateChallenge that don't - // rely on a user context and thus warrant additional rate limiting since - // they are unauthenticated (either through direct API connections or coming - // from the proxy on behalf of a remote unauthenticated user). - createAuthenticateChallengeUnauthenticatedLimiter *limiter.RateLimiter + // createAuthenticateChallengeLimiter is a rate limiter for invocations of + // /proto.AuthService/CreateAuthenticateChallenge that don't rely on a user + // context and thus warrant additional rate limiting since they are + // unauthenticated (either through direct API connections or coming from the + // proxy on behalf of a remote unauthenticated user). + createAuthenticateChallengeLimiter *limiter.RateLimiter } func (g *GRPCServer) SetServingStatus(service string, servingStatus grpc_health_v1.HealthCheckResponse_ServingStatus) { @@ -4844,7 +4844,7 @@ func (g *GRPCServer) CreateAuthenticateChallenge(ctx context.Context, req *authp return nil, trace.BadParameter("unable to find peer") } - if err := g.createAuthenticateChallengeUnauthenticatedLimiter.RegisterRequestFromAddr(peerInfo.Addr, nil); err != nil { + if err := g.createAuthenticateChallengeLimiter.RegisterRequestFromAddr(peerInfo.Addr, nil); err != nil { return nil, trace.LimitExceeded("rate limit exceeded") } } @@ -5964,6 +5964,10 @@ type GRPCServerConfig struct { UnaryInterceptors []grpc.UnaryServerInterceptor // StreamInterceptors is the gRPC stream interceptor chain. StreamInterceptors []grpc.StreamServerInterceptor + // CreateAuthenticateChallengeLimiterConfig is the optional configuration + // for the limiter applied to unauthenticated calls to + // CreateAuthenticateChallenge. Used in tests. + CreateAuthenticateChallengeLimiterConfig *limiter.Config } // CheckAndSetDefaults checks and sets default values @@ -6245,13 +6249,17 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { } grpcv1pb.RegisterServiceConfigDiscoveryServiceServer(server, grpcClientConfigService) - createAuthenticateChallengeUnauthenticatedLimiter, err := limiter.NewRateLimiter(limiter.Config{ + limiterConfig := limiter.Config{ Rates: []limiter.Rate{{ Period: defaults.LimiterPeriod, Average: defaults.LimiterAverage, Burst: defaults.LimiterBurst, }}, - }) + } + if cfg.CreateAuthenticateChallengeLimiterConfig != nil { + limiterConfig = *cfg.CreateAuthenticateChallengeLimiterConfig + } + createAuthenticateChallengeLimiter, err := limiter.NewRateLimiter(limiterConfig) if err != nil { return nil, trace.Wrap(err) } @@ -6262,7 +6270,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { server: server, healthcheck: health.NewServer(), - createAuthenticateChallengeUnauthenticatedLimiter: createAuthenticateChallengeUnauthenticatedLimiter, + createAuthenticateChallengeLimiter: createAuthenticateChallengeLimiter, } if en := os.Getenv("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT"); en != "" { diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index 3722725af76ae..5b84edc970383 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -77,6 +77,10 @@ type TLSServerConfig struct { APIConfig // LimiterConfig is limiter config LimiterConfig limiter.Config + // CreateAuthenticateChallengeLimiterConfig is the optional configuration + // for the limiter applied to unauthenticated calls to + // CreateAuthenticateChallenge. Used in tests. + CreateAuthenticateChallengeLimiterConfig *limiter.Config // AccessPoint is a caching access point AccessPoint AccessCacheWithEvents // Component is used for debugging purposes @@ -234,6 +238,8 @@ func NewTLSServer(ctx context.Context, cfg TLSServerConfig) (*TLSServer, error) APIConfig: cfg.APIConfig, UnaryInterceptors: authMiddleware.UnaryInterceptors(), StreamInterceptors: authMiddleware.StreamInterceptors(), + + CreateAuthenticateChallengeLimiterConfig: cfg.CreateAuthenticateChallengeLimiterConfig, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/service/service.go b/lib/service/service.go index 37eaa077fdcf9..75e99848650ef 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2758,13 +2758,16 @@ func (process *TeleportProcess) initAuthService() error { TLS: tlsConfig, GetClientCertificate: connector.ClientGetCertificate, - APIConfig: *apiConf, - LimiterConfig: cfg.Auth.Limiter, - AccessPoint: authServer.Cache, - Component: teleport.Component(teleport.ComponentAuth, process.id), - ID: process.id, - Listener: mux.TLS(), - Metrics: authMetrics, + APIConfig: *apiConf, + + LimiterConfig: cfg.Auth.Limiter, + CreateAuthenticateChallengeLimiterConfig: cfg.Auth.CreateAuthenticateChallengeLimiterConfig, + + AccessPoint: authServer.Cache, + Component: teleport.Component(teleport.ComponentAuth, process.id), + ID: process.id, + Listener: mux.TLS(), + Metrics: authMetrics, }) if err != nil { return trace.Wrap(err) diff --git a/lib/service/servicecfg/auth.go b/lib/service/servicecfg/auth.go index 038906ae2bbbf..d86ac40090b89 100644 --- a/lib/service/servicecfg/auth.go +++ b/lib/service/servicecfg/auth.go @@ -80,6 +80,11 @@ type AuthConfig struct { Limiter limiter.Config + // CreateAuthenticateChallengeLimiterConfig is the optional configuration + // for the limiter applied to unauthenticated calls to + // CreateAuthenticateChallenge. Used in tests. + CreateAuthenticateChallengeLimiterConfig *limiter.Config + // NoAudit, when set to true, disables session recording and event audit NoAudit bool diff --git a/tool/tctl/common/admin_action_test.go b/tool/tctl/common/admin_action_test.go index 0f287f7cf8207..10686f871a3e4 100644 --- a/tool/tctl/common/admin_action_test.go +++ b/tool/tctl/common/admin_action_test.go @@ -50,6 +50,7 @@ import ( libclient "github.com/gravitational/teleport/lib/client" libmfa "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/modules/modulestest" "github.com/gravitational/teleport/lib/service/servicecfg" @@ -1077,6 +1078,7 @@ func newAdminActionTestSuite(t *testing.T) *adminActionTestSuite { proxyPublicAddr = cfg.Proxy.WebAddr proxyPublicAddr.Addr = fmt.Sprintf("localhost:%v", proxyPublicAddr.Port(0)) cfg.Proxy.PublicAddrs = []utils.NetAddr{proxyPublicAddr} + cfg.Auth.CreateAuthenticateChallengeLimiterConfig = &limiter.Config{} }), ) require.NoError(t, err) From f74fb7dba3f5c48b727ce469c4c44f6128bed0eb Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 13 Mar 2026 19:04:47 +0100 Subject: [PATCH 07/11] Use a regular require.Greater check instead of a static assertion --- lib/auth/grpcserver_test.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index e556d48a76bd5..98d2f755f1de6 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -4743,11 +4743,8 @@ func synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge(t *tes _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) require.ErrorAs(t, err, new(*trace.LimitExceededError)) - _ = func() { - // static assertion that waiting 1000 periods should bring us to LimiterBurst but no higher - const mustBeTrue = 1000*defaults.LimiterAverage > defaults.LimiterBurst - _ = map[bool]struct{}{false: struct{}{}, mustBeTrue: struct{}{}} - } + // waiting 1000 periods should bring us to LimiterBurst but no higher + require.Greater(t, 1000*defaults.LimiterAverage, defaults.LimiterBurst) time.Sleep(1000 * defaults.LimiterPeriod) for range defaults.LimiterBurst { From c2ef38698a018cb55194477d4f663aea9a3abb66 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 13 Mar 2026 19:08:14 +0100 Subject: [PATCH 08/11] Update lib/auth/grpcserver_test.go Co-authored-by: Alan Parra <12500300+codingllama@users.noreply.github.com> --- lib/auth/grpcserver_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 98d2f755f1de6..340778112b0bc 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -4743,8 +4743,8 @@ func synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge(t *tes _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) require.ErrorAs(t, err, new(*trace.LimitExceededError)) - // waiting 1000 periods should bring us to LimiterBurst but no higher - require.Greater(t, 1000*defaults.LimiterAverage, defaults.LimiterBurst) + require.Greater(t, 1000*defaults.LimiterAverage, defaults.LimiterBurst, + "Waiting 1000 periods should bring us to LimiterBurst but no higher") time.Sleep(1000 * defaults.LimiterPeriod) for range defaults.LimiterBurst { From 1089d3e0353a569c8087c41626dbc895c3424ee9 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 13 Mar 2026 19:42:52 +0100 Subject: [PATCH 09/11] Treat a nil request as a ContextUser request --- lib/auth/grpcserver.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 9e9dcaca72be3..e9791db5e63da 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -4834,8 +4834,9 @@ func (g *GRPCServer) CreateAuthenticateChallenge(ctx context.Context, req *authp return nil, trace.Wrap(err) } - if req.GetContextUser() == nil { - // requests with a user context will be checked for legitimacy by + if req.GetContextUser() == nil && req.GetRequest() != nil { + // requests with a user context (or with an empty request, which is + // considered to be the same) will be checked for legitimacy by // ServerWithRoles later, so we only need to care about adding an // additional rate limit for the non-user-context requests here From e9965f59a09b429b9a0cbe3aef74912833b933f5 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 13 Mar 2026 19:46:20 +0100 Subject: [PATCH 10/11] Undo changes to make the limiter configurable --- lib/auth/grpcserver.go | 12 ++---------- lib/auth/middleware.go | 6 ------ lib/service/service.go | 17 +++++++---------- lib/service/servicecfg/auth.go | 5 ----- tool/tctl/common/admin_action_test.go | 2 -- 5 files changed, 9 insertions(+), 33 deletions(-) diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index e9791db5e63da..7c0ad70bdb7fd 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5965,10 +5965,6 @@ type GRPCServerConfig struct { UnaryInterceptors []grpc.UnaryServerInterceptor // StreamInterceptors is the gRPC stream interceptor chain. StreamInterceptors []grpc.StreamServerInterceptor - // CreateAuthenticateChallengeLimiterConfig is the optional configuration - // for the limiter applied to unauthenticated calls to - // CreateAuthenticateChallenge. Used in tests. - CreateAuthenticateChallengeLimiterConfig *limiter.Config } // CheckAndSetDefaults checks and sets default values @@ -6250,17 +6246,13 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { } grpcv1pb.RegisterServiceConfigDiscoveryServiceServer(server, grpcClientConfigService) - limiterConfig := limiter.Config{ + createAuthenticateChallengeLimiter, err := limiter.NewRateLimiter(limiter.Config{ Rates: []limiter.Rate{{ Period: defaults.LimiterPeriod, Average: defaults.LimiterAverage, Burst: defaults.LimiterBurst, }}, - } - if cfg.CreateAuthenticateChallengeLimiterConfig != nil { - limiterConfig = *cfg.CreateAuthenticateChallengeLimiterConfig - } - createAuthenticateChallengeLimiter, err := limiter.NewRateLimiter(limiterConfig) + }) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index 5b84edc970383..3722725af76ae 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -77,10 +77,6 @@ type TLSServerConfig struct { APIConfig // LimiterConfig is limiter config LimiterConfig limiter.Config - // CreateAuthenticateChallengeLimiterConfig is the optional configuration - // for the limiter applied to unauthenticated calls to - // CreateAuthenticateChallenge. Used in tests. - CreateAuthenticateChallengeLimiterConfig *limiter.Config // AccessPoint is a caching access point AccessPoint AccessCacheWithEvents // Component is used for debugging purposes @@ -238,8 +234,6 @@ func NewTLSServer(ctx context.Context, cfg TLSServerConfig) (*TLSServer, error) APIConfig: cfg.APIConfig, UnaryInterceptors: authMiddleware.UnaryInterceptors(), StreamInterceptors: authMiddleware.StreamInterceptors(), - - CreateAuthenticateChallengeLimiterConfig: cfg.CreateAuthenticateChallengeLimiterConfig, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/service/service.go b/lib/service/service.go index 75e99848650ef..37eaa077fdcf9 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2758,16 +2758,13 @@ func (process *TeleportProcess) initAuthService() error { TLS: tlsConfig, GetClientCertificate: connector.ClientGetCertificate, - APIConfig: *apiConf, - - LimiterConfig: cfg.Auth.Limiter, - CreateAuthenticateChallengeLimiterConfig: cfg.Auth.CreateAuthenticateChallengeLimiterConfig, - - AccessPoint: authServer.Cache, - Component: teleport.Component(teleport.ComponentAuth, process.id), - ID: process.id, - Listener: mux.TLS(), - Metrics: authMetrics, + APIConfig: *apiConf, + LimiterConfig: cfg.Auth.Limiter, + AccessPoint: authServer.Cache, + Component: teleport.Component(teleport.ComponentAuth, process.id), + ID: process.id, + Listener: mux.TLS(), + Metrics: authMetrics, }) if err != nil { return trace.Wrap(err) diff --git a/lib/service/servicecfg/auth.go b/lib/service/servicecfg/auth.go index d86ac40090b89..038906ae2bbbf 100644 --- a/lib/service/servicecfg/auth.go +++ b/lib/service/servicecfg/auth.go @@ -80,11 +80,6 @@ type AuthConfig struct { Limiter limiter.Config - // CreateAuthenticateChallengeLimiterConfig is the optional configuration - // for the limiter applied to unauthenticated calls to - // CreateAuthenticateChallenge. Used in tests. - CreateAuthenticateChallengeLimiterConfig *limiter.Config - // NoAudit, when set to true, disables session recording and event audit NoAudit bool diff --git a/tool/tctl/common/admin_action_test.go b/tool/tctl/common/admin_action_test.go index 10686f871a3e4..0f287f7cf8207 100644 --- a/tool/tctl/common/admin_action_test.go +++ b/tool/tctl/common/admin_action_test.go @@ -50,7 +50,6 @@ import ( libclient "github.com/gravitational/teleport/lib/client" libmfa "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/cryptosuites" - "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/modules/modulestest" "github.com/gravitational/teleport/lib/service/servicecfg" @@ -1078,7 +1077,6 @@ func newAdminActionTestSuite(t *testing.T) *adminActionTestSuite { proxyPublicAddr = cfg.Proxy.WebAddr proxyPublicAddr.Addr = fmt.Sprintf("localhost:%v", proxyPublicAddr.Port(0)) cfg.Proxy.PublicAddrs = []utils.NetAddr{proxyPublicAddr} - cfg.Auth.CreateAuthenticateChallengeLimiterConfig = &limiter.Config{} }), ) require.NoError(t, err) From 24e73f6023a13a1f3eb1a80353118ad831b06968 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 13 Mar 2026 20:01:39 +0100 Subject: [PATCH 11/11] fix TestCustomRateLimiting --- lib/auth/grpcserver_test.go | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 340778112b0bc..c433e351ab2f9 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -4723,24 +4723,35 @@ func synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge(t *tes require.NoError(t, err) defer clt.Close() + nonContextUserRequest := &proto.CreateAuthenticateChallengeRequest{ + Request: &proto.CreateAuthenticateChallengeRequest_UserCredentials{ + UserCredentials: new(proto.UserCredentials), + }, + } + contextUserRequest := &proto.CreateAuthenticateChallengeRequest{ + Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{ + ContextUser: new(proto.ContextUser), + }, + } + for range defaults.LimiterBurst { - _, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err := clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.NotErrorAs(t, err, new(*trace.LimitExceededError)) } - _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err = clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.ErrorAs(t, err, new(*trace.LimitExceededError)) - _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err = clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.ErrorAs(t, err, new(*trace.LimitExceededError)) time.Sleep(defaults.LimiterPeriod) for range defaults.LimiterBurst - defaults.LimiterAverage { - _, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err := clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.NotErrorAs(t, err, new(*trace.LimitExceededError)) } - _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err = clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.ErrorAs(t, err, new(*trace.LimitExceededError)) - _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err = clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.ErrorAs(t, err, new(*trace.LimitExceededError)) require.Greater(t, 1000*defaults.LimiterAverage, defaults.LimiterBurst, @@ -4748,12 +4759,12 @@ func synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge(t *tes time.Sleep(1000 * defaults.LimiterPeriod) for range defaults.LimiterBurst { - _, err := clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err := clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.NotErrorAs(t, err, new(*trace.LimitExceededError)) } - _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err = clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.ErrorAs(t, err, new(*trace.LimitExceededError)) - _, err = clt.CreateAuthenticateChallenge(ctx, new(proto.CreateAuthenticateChallengeRequest)) + _, err = clt.CreateAuthenticateChallenge(ctx, nonContextUserRequest) require.ErrorAs(t, err, new(*trace.LimitExceededError)) // no time has passed, but we can do a full burst and more if we pretend to @@ -4761,11 +4772,7 @@ func synctestCustomRateLimitingUnauthenticatedCreateAuthenticateChallenge(t *tes // limiter) for range defaults.LimiterBurst + 1 { - _, err := clt.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ - Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{ - ContextUser: &proto.ContextUser{}, - }, - }) + _, err := clt.CreateAuthenticateChallenge(ctx, contextUserRequest) require.NotErrorAs(t, err, new(*trace.LimitExceededError)) } }