@@ -10,29 +10,48 @@ import (
1010 "time"
1111)
1212
13+ type serverResult struct {
14+ storage * TokenStorage
15+ err error
16+ }
17+
1318// startCallbackServerAsync starts the callback server in a goroutine and
14- // returns a channel that will receive the authorization code (or error string).
15- func startCallbackServerAsync (t * testing.T , port int , state string ) chan string {
19+ // returns a channel that will receive the final result (storage or error).
20+ func startCallbackServerAsync (
21+ t * testing.T ,
22+ port int ,
23+ state string ,
24+ exchangeFn func (ctx context.Context , code string ) (* TokenStorage , error ),
25+ ) chan serverResult {
1626 t .Helper ()
17- ch := make (chan string , 1 )
27+ ch := make (chan serverResult , 1 )
1828 go func () {
19- code , err := startCallbackServer (context .Background (), port , state )
20- if err != nil {
21- ch <- "ERROR:" + err .Error ()
22- } else {
23- ch <- code
24- }
29+ storage , err := startCallbackServer (context .Background (), port , state , exchangeFn )
30+ ch <- serverResult {storage : storage , err : err }
2531 }()
2632 // Give the server a moment to bind.
2733 time .Sleep (50 * time .Millisecond )
2834 return ch
2935}
3036
37+ // mockExchangeFn returns an exchangeFn that succeeds with a stub TokenStorage.
38+ func mockExchangeFn (t * testing.T ) func (ctx context.Context , code string ) (* TokenStorage , error ) {
39+ t .Helper ()
40+ return func (_ context.Context , _ string ) (* TokenStorage , error ) {
41+ return & TokenStorage {
42+ AccessToken : "mock-access-token" ,
43+ RefreshToken : "mock-refresh-token" ,
44+ TokenType : "Bearer" ,
45+ ExpiresAt : time .Now ().Add (time .Hour ),
46+ }, nil
47+ }
48+ }
49+
3150func TestCallbackServer_Success (t * testing.T ) {
3251 const port = 19001
3352 state := "test-state-success"
3453
35- ch := startCallbackServerAsync (t , port , state )
54+ ch := startCallbackServerAsync (t , port , state , mockExchangeFn ( t ) )
3655
3756 // Simulate the browser redirect.
3857 callbackURL := fmt .Sprintf (
@@ -53,11 +72,57 @@ func TestCallbackServer_Success(t *testing.T) {
5372 t .Errorf ("expected success page, got: %s" , string (body ))
5473 }
5574
56- // Check code returned to CLI.
75+ // Check that storage is returned to the CLI.
76+ select {
77+ case result := <- ch :
78+ if result .err != nil {
79+ t .Errorf ("expected no error, got: %v" , result .err )
80+ }
81+ if result .storage == nil || result .storage .AccessToken != "mock-access-token" {
82+ t .Errorf ("unexpected storage: %+v" , result .storage )
83+ }
84+ case <- time .After (3 * time .Second ):
85+ t .Fatal ("timed out waiting for callback result" )
86+ }
87+ }
88+
89+ func TestCallbackServer_ExchangeFailure (t * testing.T ) {
90+ const port = 19006
91+ state := "test-state-exchange-fail"
92+
93+ failFn := func (_ context.Context , _ string ) (* TokenStorage , error ) {
94+ return nil , fmt .Errorf ("server returned status 400: invalid_grant" )
95+ }
96+ ch := startCallbackServerAsync (t , port , state , failFn )
97+
98+ callbackURL := fmt .Sprintf (
99+ "http://127.0.0.1:%d/callback?code=badcode&state=%s" ,
100+ port , state ,
101+ )
102+ resp , err := http .Get (callbackURL ) //nolint:noctx,gosec
103+ if err != nil {
104+ t .Fatalf ("GET callback failed: %v" , err )
105+ }
106+ defer resp .Body .Close ()
107+
108+ body , _ := io .ReadAll (resp .Body )
109+ if resp .StatusCode != http .StatusOK {
110+ t .Errorf ("unexpected status %d" , resp .StatusCode )
111+ }
112+ if ! strings .Contains (string (body ), "Authorization Failed" ) {
113+ t .Errorf ("expected failure page, got: %s" , string (body ))
114+ }
115+ if ! strings .Contains (string (body ), "invalid_grant" ) {
116+ t .Errorf ("expected error detail in page, got: %s" , string (body ))
117+ }
118+
57119 select {
58120 case result := <- ch :
59- if result != "mycode123" {
60- t .Errorf ("expected code mycode123, got: %s" , result )
121+ if result .err == nil {
122+ t .Error ("expected an error, got nil" )
123+ }
124+ if result .storage != nil {
125+ t .Errorf ("expected nil storage, got: %+v" , result .storage )
61126 }
62127 case <- time .After (3 * time .Second ):
63128 t .Fatal ("timed out waiting for callback result" )
@@ -68,7 +133,7 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
68133 const port = 19002
69134 state := "expected-state"
70135
71- ch := startCallbackServerAsync (t , port , state )
136+ ch := startCallbackServerAsync (t , port , state , nil )
72137
73138 callbackURL := fmt .Sprintf (
74139 "http://127.0.0.1:%d/callback?code=mycode&state=wrong-state" ,
@@ -87,8 +152,8 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
87152
88153 select {
89154 case result := <- ch :
90- if ! strings . HasPrefix ( result , "ERROR:" ) {
91- t .Errorf ("expected error for state mismatch, got: %s" , result )
155+ if result . err == nil {
156+ t .Errorf ("expected error for state mismatch, got nil" )
92157 }
93158 case <- time .After (3 * time .Second ):
94159 t .Fatal ("timed out waiting for callback result" )
@@ -99,7 +164,7 @@ func TestCallbackServer_OAuthError(t *testing.T) {
99164 const port = 19003
100165 state := "state-for-error"
101166
102- ch := startCallbackServerAsync (t , port , state )
167+ ch := startCallbackServerAsync (t , port , state , nil )
103168
104169 callbackURL := fmt .Sprintf (
105170 "http://127.0.0.1:%d/callback?error=access_denied&error_description=User+denied&state=%s" ,
@@ -118,11 +183,11 @@ func TestCallbackServer_OAuthError(t *testing.T) {
118183
119184 select {
120185 case result := <- ch :
121- if ! strings . HasPrefix ( result , "ERROR:" ) {
122- t .Errorf ("expected error for access_denied, got: %s" , result )
186+ if result . err == nil {
187+ t .Errorf ("expected error for access_denied, got nil" )
123188 }
124- if ! strings .Contains (result , "access_denied" ) {
125- t .Errorf ("expected error to mention access_denied, got: %s " , result )
189+ if ! strings .Contains (result . err . Error () , "access_denied" ) {
190+ t .Errorf ("expected error to mention access_denied, got: %v " , result . err )
126191 }
127192 case <- time .After (3 * time .Second ):
128193 t .Fatal ("timed out waiting for callback result" )
@@ -137,7 +202,7 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
137202 const port = 19005
138203 state := "test-state-double"
139204
140- ch := startCallbackServerAsync (t , port , state )
205+ ch := startCallbackServerAsync (t , port , state , mockExchangeFn ( t ) )
141206
142207 url := fmt .Sprintf ("http://127.0.0.1:%d/callback?code=mycode&state=%s" , port , state )
143208
@@ -163,11 +228,14 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
163228 }
164229 }
165230
166- // startCallbackServer must also return promptly.
231+ // startCallbackServer must also return promptly with a valid storage .
167232 select {
168233 case result := <- ch :
169- if result != "mycode" {
170- t .Errorf ("expected mycode, got: %s" , result )
234+ if result .err != nil {
235+ t .Errorf ("expected no error, got: %v" , result .err )
236+ }
237+ if result .storage == nil {
238+ t .Error ("expected non-nil storage" )
171239 }
172240 case <- time .After (3 * time .Second ):
173241 t .Fatal ("timed out waiting for callback result" )
@@ -178,7 +246,7 @@ func TestCallbackServer_MissingCode(t *testing.T) {
178246 const port = 19004
179247 state := "state-for-missing-code"
180248
181- ch := startCallbackServerAsync (t , port , state )
249+ ch := startCallbackServerAsync (t , port , state , nil )
182250
183251 // Correct state but no code parameter.
184252 callbackURL := fmt .Sprintf (
@@ -193,8 +261,8 @@ func TestCallbackServer_MissingCode(t *testing.T) {
193261
194262 select {
195263 case result := <- ch :
196- if ! strings . HasPrefix ( result , "ERROR:" ) {
197- t .Errorf ("expected error for missing code, got: %s" , result )
264+ if result . err == nil {
265+ t .Errorf ("expected error for missing code, got nil" )
198266 }
199267 case <- time .After (3 * time .Second ):
200268 t .Fatal ("timed out waiting for callback result" )
0 commit comments