@@ -12,16 +12,26 @@ import (
1212 "net/http/httptest"
1313 "net/url"
1414 "testing"
15+
16+ "golang.org/x/oauth2"
1517)
1618
17- func newConf (serverURL string ) * Config {
18- return & Config {
19+ func newConf (serverURL string , assertion bool ) * Config {
20+ conf := & Config {
1921 ClientID : "CLIENT_ID" ,
20- ClientSecret : "CLIENT_SECRET" ,
2122 Scopes : []string {"scope1" , "scope2" },
2223 TokenURL : serverURL + "/token" ,
2324 EndpointParams : url.Values {"audience" : {"audience1" }},
25+ AuthStyle : oauth2 .AuthStyleInParams ,
26+ }
27+ if assertion {
28+ conf .ClientAssertionFn = func (ctx context.Context ) (string , error ) {
29+ return "CLIENT_ASSERTION" , nil
30+ }
31+ } else {
32+ conf .ClientSecret = "CLIENT_SECRET"
2433 }
34+ return conf
2535}
2636
2737type mockTransport struct {
@@ -69,45 +79,70 @@ func TestTokenSourceGrantTypeOverride(t *testing.T) {
6979 }
7080}
7181
82+ func assert (t * testing.T , want , got string ) {
83+ t .Helper ()
84+ if got != want {
85+ t .Errorf ("got %q; want %q" , got , want )
86+ }
87+ }
88+
7289func TestTokenRequest (t * testing.T ) {
7390 ts := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
7491 if r .URL .String () != "/token" {
7592 t .Errorf ("authenticate client request URL = %q; want %q" , r .URL , "/token" )
7693 }
77- headerAuth := r .Header .Get ("Authorization" )
78- if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
79- t .Errorf ("Unexpected authorization header, %v is found." , headerAuth )
80- }
94+
8195 if got , want := r .Header .Get ("Content-Type" ), "application/x-www-form-urlencoded" ; got != want {
8296 t .Errorf ("Content-Type header = %q; want %q" , got , want )
8397 }
84- body , err := ioutil .ReadAll (r .Body )
85- if err != nil {
86- r .Body .Close ()
87- }
88- if err != nil {
89- t .Errorf ("failed reading request body: %s." , err )
90- }
91- if string (body ) != "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" {
92- t .Errorf ("payload = %q; want %q" , string (body ), "grant_type=client_credentials&scope=scope1+scope2" )
98+
99+ assert (t , "audience1" , r .FormValue ("audience" ))
100+ assert (t , "CLIENT_ID" , r .FormValue ("client_id" ))
101+ assert (t , "client_credentials" , r .FormValue ("grant_type" ))
102+ assert (t , "scope1 scope2" , r .FormValue ("scope" ))
103+ if r .FormValue ("client_secret" ) != "" {
104+ assert (t , "CLIENT_SECRET" , r .FormValue ("client_secret" ))
105+ } else {
106+ assert (t , "CLIENT_ASSERTION" , r .FormValue ("client_assertion" ))
107+ assert (t , "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" , r .FormValue ("client_assertion_type" ))
93108 }
94109 w .Header ().Set ("Content-Type" , "application/x-www-form-urlencoded" )
95110 w .Write ([]byte ("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer" ))
96111 }))
97112 defer ts .Close ()
98- conf := newConf (ts .URL )
99- tok , err := conf .Token (context .Background ())
100- if err != nil {
101- t .Error (err )
102- }
103- if ! tok .Valid () {
104- t .Fatalf ("token invalid. got: %#v" , tok )
113+
114+ type testCase struct {
115+ name string
116+ conf * Config
105117 }
106- if tok .AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
107- t .Errorf ("Access token = %q; want %q" , tok .AccessToken , "90d64460d14870c08c81352a05dedd3465940a7c" )
118+
119+ tests := []testCase {
120+ {
121+ name : "client id and client_secret" ,
122+ conf : newConf (ts .URL , false ),
123+ },
124+ {
125+ name : "client id and client_assertion" ,
126+ conf : newConf (ts .URL , true ),
127+ },
108128 }
109- if tok .TokenType != "bearer" {
110- t .Errorf ("token type = %q; want %q" , tok .TokenType , "bearer" )
129+
130+ for _ , tc := range tests {
131+ t .Run (tc .name , func (t * testing.T ) {
132+ tok , err := tc .conf .Token (context .Background ())
133+ if err != nil {
134+ t .Error (err )
135+ }
136+ if ! tok .Valid () {
137+ t .Fatalf ("token invalid. got: %#v" , tok )
138+ }
139+ if tok .AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
140+ t .Errorf ("Access token = %q; want %q" , tok .AccessToken , "90d64460d14870c08c81352a05dedd3465940a7c" )
141+ }
142+ if tok .TokenType != "bearer" {
143+ t .Errorf ("token type = %q; want %q" , tok .TokenType , "bearer" )
144+ }
145+ })
111146 }
112147}
113148
@@ -132,7 +167,7 @@ func TestTokenRefreshRequest(t *testing.T) {
132167 io .WriteString (w , `{"access_token": "foo", "refresh_token": "bar"}` )
133168 }))
134169 defer ts .Close ()
135- conf := newConf (ts .URL )
170+ conf := newConf (ts .URL , false )
136171 c := conf .Client (context .Background ())
137172 c .Get (ts .URL + "/somethingelse" )
138173}
0 commit comments