@@ -13,6 +13,7 @@ import (
1313 "testing"
1414
1515 "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
16+ "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
1617 "github.com/cockroachdb/cockroach/pkg/util/leaktest"
1718 "github.com/jackc/pgproto3/v2"
1819 "github.com/stretchr/testify/assert"
@@ -98,14 +99,14 @@ func TestAuthenticateClearText(t *testing.T) {
9899func TestAuthenticateThrottled (t * testing.T ) {
99100 defer leaktest .AfterTest (t )()
100101
101- server := func (t * testing.T , be * pgproto3.Backend , authResponse pgproto3. BackendMessage ) {
102+ server := func (t * testing.T , be * pgproto3.Backend ) {
102103 require .NoError (t , be .Send (& pgproto3.AuthenticationCleartextPassword {}))
103104
104105 msg , err := be .Receive ()
105106 require .NoError (t , err )
106107 require .Equal (t , msg , & pgproto3.PasswordMessage {Password : "password" })
107108
108- require .NoError (t , be .Send (authResponse ))
109+ require .NoError (t , be .Send (& pgproto3. ErrorResponse { Message : "wrong password" } ))
109110 }
110111
111112 client := func (t * testing.T , fe * pgproto3.Frontend ) {
@@ -130,44 +131,79 @@ func TestAuthenticateThrottled(t *testing.T) {
130131 require .Error (t , err )
131132 }
132133
133- type testCase struct {
134- name string
135- result pgproto3.BackendMessage
136- expectedStatus throttler.AttemptStatus
137- }
138- for _ , tc := range []testCase {
139- {
140- name : "AuthenticationOkay" ,
141- result : & pgproto3.AuthenticationOk {},
142- expectedStatus : throttler .AttemptOK ,
143- },
144- {
145- name : "AuthenticationError" ,
146- result : & pgproto3.ErrorResponse {Message : "wrong password" },
147- expectedStatus : throttler .AttemptInvalidCredentials ,
148- },
149- } {
150- t .Run (tc .name , func (t * testing.T ) {
151- proxyToServer , serverToProxy := net .Pipe ()
152- proxyToClient , clientToProxy := net .Pipe ()
153- sqlServer := pgproto3 .NewBackend (pgproto3 .NewChunkReader (serverToProxy ), serverToProxy )
154- sqlClient := pgproto3 .NewFrontend (pgproto3 .NewChunkReader (clientToProxy ), clientToProxy )
155-
156- go server (t , sqlServer , & pgproto3.AuthenticationOk {})
157- go client (t , sqlClient )
158-
159- _ , err := authenticate (proxyToClient , proxyToServer , nil , /* proxyBackendKeyData */
160- func (status throttler.AttemptStatus ) error {
161- require .Equal (t , throttler .AttemptOK , status )
162- return throttledError
163- })
164- require .Error (t , err )
165- require .Contains (t , err .Error (), "connection attempt throttled" )
166-
167- proxyToServer .Close ()
168- proxyToClient .Close ()
134+ proxyToServer , serverToProxy := net .Pipe ()
135+ proxyToClient , clientToProxy := net .Pipe ()
136+ sqlServer := pgproto3 .NewBackend (pgproto3 .NewChunkReader (serverToProxy ), serverToProxy )
137+ sqlClient := pgproto3 .NewFrontend (pgproto3 .NewChunkReader (clientToProxy ), clientToProxy )
138+
139+ go server (t , sqlServer )
140+ go client (t , sqlClient )
141+
142+ _ , err := authenticate (proxyToClient , proxyToServer , nil , /* proxyBackendKeyData */
143+ func (status throttler.AttemptStatus ) error {
144+ require .Equal (t , throttler .AttemptInvalidCredentials , status )
145+ return throttledError
169146 })
147+ require .Error (t , err )
148+ require .Contains (t , err .Error (), "connection attempt throttled" )
149+
150+ proxyToServer .Close ()
151+ proxyToClient .Close ()
152+ }
153+
154+ func TestErrorFollowingAuthenticateNotThrottled (t * testing.T ) {
155+ defer leaktest .AfterTest (t )()
156+
157+ server := func (t * testing.T , be * pgproto3.Backend ) {
158+ require .NoError (t , be .Send (& pgproto3.AuthenticationCleartextPassword {}))
159+
160+ msg , err := be .Receive ()
161+ require .NoError (t , err )
162+ require .Equal (t , msg , & pgproto3.PasswordMessage {Password : "password" })
163+
164+ require .NoError (t , be .Send (& pgproto3.ErrorResponse {
165+ Code : pgcode .TooManyConnections .String (),
166+ Message : "sorry, too many clients already" }))
170167 }
168+
169+ client := func (t * testing.T , fe * pgproto3.Frontend ) {
170+ msg , err := fe .Receive ()
171+ require .NoError (t , err )
172+ require .Equal (t , msg , & pgproto3.AuthenticationCleartextPassword {})
173+
174+ require .NoError (t , fe .Send (& pgproto3.PasswordMessage {Password : "password" }))
175+
176+ // Try reading from the connection. This check ensures authorize
177+ // swallowed the OK/Error response from the sql server.
178+ msg , err = fe .Receive ()
179+ require .NoError (t , err )
180+ require .Equal (t , msg , & pgproto3.ErrorResponse {
181+ Code : pgcode .TooManyConnections .String (),
182+ Message : "sorry, too many clients already" })
183+ }
184+
185+ proxyToServer , serverToProxy := net .Pipe ()
186+ proxyToClient , clientToProxy := net .Pipe ()
187+ sqlServer := pgproto3 .NewBackend (pgproto3 .NewChunkReader (serverToProxy ), serverToProxy )
188+ sqlClient := pgproto3 .NewFrontend (pgproto3 .NewChunkReader (clientToProxy ), clientToProxy )
189+
190+ go server (t , sqlServer )
191+ go client (t , sqlClient )
192+
193+ var throttleCallCount int
194+ _ , err := authenticate (proxyToClient , proxyToServer , nil , /* proxyBackendKeyData */
195+ func (status throttler.AttemptStatus ) error {
196+ throttleCallCount ++
197+ require .Equal (t , throttler .AttemptOK , status )
198+ return nil
199+ })
200+ require .Equal (t , 1 , throttleCallCount )
201+ require .Error (t , err )
202+ require .Contains (t , err .Error (),
203+ "codeAuthFailed: authentication failed: sorry, too many clients already" )
204+
205+ proxyToServer .Close ()
206+ proxyToClient .Close ()
171207}
172208
173209func TestAuthenticateError (t * testing.T ) {
0 commit comments