Skip to content

Commit f564fe4

Browse files
markandrusmislav
andauthored
Support Google "OAuth 2.0 for TV and Limited-Input Device Applications" (#24)
Co-authored-by: Mislav Marohnić <[email protected]>
1 parent 612a4a5 commit f564fe4

File tree

4 files changed

+141
-49
lines changed

4 files changed

+141
-49
lines changed

device/device_flow.go

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ type CodeResponse struct {
5151
// The minimum number of seconds that must pass before you can make a new access token request to
5252
// complete the device authorization.
5353
Interval int
54-
55-
timeNow func() time.Time
56-
timeSleep func(time.Duration)
5754
}
5855

5956
// RequestCode initiates the authorization flow by requesting a code from uri.
@@ -67,6 +64,10 @@ func RequestCode(c httpClient, uri string, clientID string, scopes []string) (*C
6764
}
6865

6966
verificationURI := resp.Get("verification_uri")
67+
if verificationURI == "" {
68+
// Google's "OAuth 2.0 for TV and Limited-Input Device Applications" uses `verification_url`.
69+
verificationURI = resp.Get("verification_url")
70+
}
7071

7172
if resp.StatusCode == 401 || resp.StatusCode == 403 || resp.StatusCode == 404 || resp.StatusCode == 422 ||
7273
(resp.StatusCode == 200 && verificationURI == "") ||
@@ -98,30 +99,66 @@ func RequestCode(c httpClient, uri string, clientID string, scopes []string) (*C
9899
}, nil
99100
}
100101

101-
const grantType = "urn:ietf:params:oauth:grant-type:device_code"
102+
const defaultGrantType = "urn:ietf:params:oauth:grant-type:device_code"
102103

103104
// PollToken polls the server at pollURL until an access token is granted or denied.
105+
//
106+
// Deprecated: use PollTokenWithOptions.
104107
func PollToken(c httpClient, pollURL string, clientID string, code *CodeResponse) (*api.AccessToken, error) {
105-
timeNow := code.timeNow
108+
return PollTokenWithOptions(c, pollURL, PollOptions{
109+
ClientID: clientID,
110+
DeviceCode: code,
111+
})
112+
}
113+
114+
// PollOptions specifies parameters to poll the server with until authentication completes.
115+
type PollOptions struct {
116+
// ClientID is the app client ID value.
117+
ClientID string
118+
// ClientSecret is the app client secret value. Optional: only pass if the server requires it.
119+
ClientSecret string
120+
// DeviceCode is the value obtained from RequestCode.
121+
DeviceCode *CodeResponse
122+
// GrantType overrides the default value specified by OAuth 2.0 Device Code. Optional.
123+
GrantType string
124+
125+
timeNow func() time.Time
126+
timeSleep func(time.Duration)
127+
}
128+
129+
// PollTokenWithOptions polls the server at uri until authorization completes.
130+
func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.AccessToken, error) {
131+
timeNow := opts.timeNow
106132
if timeNow == nil {
107133
timeNow = time.Now
108134
}
109-
timeSleep := code.timeSleep
135+
timeSleep := opts.timeSleep
110136
if timeSleep == nil {
111137
timeSleep = time.Sleep
112138
}
113139

114-
checkInterval := time.Duration(code.Interval) * time.Second
115-
expiresAt := timeNow().Add(time.Duration(code.ExpiresIn) * time.Second)
140+
checkInterval := time.Duration(opts.DeviceCode.Interval) * time.Second
141+
expiresAt := timeNow().Add(time.Duration(opts.DeviceCode.ExpiresIn) * time.Second)
142+
grantType := opts.GrantType
143+
if opts.GrantType == "" {
144+
grantType = defaultGrantType
145+
}
116146

117147
for {
118148
timeSleep(checkInterval)
119149

120-
resp, err := api.PostForm(c, pollURL, url.Values{
121-
"client_id": {clientID},
122-
"device_code": {code.DeviceCode},
150+
values := url.Values{
151+
"client_id": {opts.ClientID},
152+
"device_code": {opts.DeviceCode.DeviceCode},
123153
"grant_type": {grantType},
124-
})
154+
}
155+
156+
// Google's "OAuth 2.0 for TV and Limited-Input Device Applications" requires `client_secret`.
157+
if opts.ClientSecret != "" {
158+
values.Add("client_secret", opts.ClientSecret)
159+
}
160+
161+
resp, err := api.PostForm(c, uri, values)
125162
if err != nil {
126163
return nil, err
127164
}

device/device_flow_test.go

Lines changed: 84 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,9 @@ func TestPollToken(t *testing.T) {
249249
}
250250

251251
type args struct {
252-
http apiClient
253-
url string
254-
clientID string
255-
code *CodeResponse
252+
http apiClient
253+
url string
254+
opts PollOptions
256255
}
257256
tests := []struct {
258257
name string
@@ -279,16 +278,18 @@ func TestPollToken(t *testing.T) {
279278
},
280279
},
281280
},
282-
url: "https://github.com/oauth",
283-
clientID: "CLIENT-ID",
284-
code: &CodeResponse{
285-
DeviceCode: "DEVIC",
286-
UserCode: "123-abc",
287-
VerificationURI: "http://verify.me",
288-
ExpiresIn: 99,
289-
Interval: 5,
290-
timeSleep: mockSleep,
291-
timeNow: clock("0", "5s", "10s"),
281+
url: "https://github.com/oauth",
282+
opts: PollOptions{
283+
ClientID: "CLIENT-ID",
284+
DeviceCode: &CodeResponse{
285+
DeviceCode: "DEVIC",
286+
UserCode: "123-abc",
287+
VerificationURI: "http://verify.me",
288+
ExpiresIn: 99,
289+
Interval: 5,
290+
},
291+
timeSleep: mockSleep,
292+
timeNow: clock("0", "5s", "10s"),
292293
},
293294
},
294295
want: &api.AccessToken{
@@ -314,6 +315,50 @@ func TestPollToken(t *testing.T) {
314315
},
315316
},
316317
},
318+
{
319+
name: "with client secret and grant type",
320+
args: args{
321+
http: apiClient{
322+
stubs: []apiStub{
323+
{
324+
body: "access_token=123abc",
325+
status: 200,
326+
contentType: "application/x-www-form-urlencoded; charset=utf-8",
327+
},
328+
},
329+
},
330+
url: "https://github.com/oauth",
331+
opts: PollOptions{
332+
ClientID: "CLIENT-ID",
333+
ClientSecret: "SEKRIT",
334+
GrantType: "device_code",
335+
DeviceCode: &CodeResponse{
336+
DeviceCode: "DEVIC",
337+
UserCode: "123-abc",
338+
VerificationURI: "http://verify.me",
339+
ExpiresIn: 99,
340+
Interval: 5,
341+
},
342+
timeSleep: mockSleep,
343+
timeNow: clock("0", "5s", "10s"),
344+
},
345+
},
346+
want: &api.AccessToken{
347+
Token: "123abc",
348+
},
349+
slept: duration("5s"),
350+
posts: []postArgs{
351+
{
352+
url: "https://github.com/oauth",
353+
params: url.Values{
354+
"client_id": {"CLIENT-ID"},
355+
"client_secret": {"SEKRIT"},
356+
"device_code": {"DEVIC"},
357+
"grant_type": {"device_code"},
358+
},
359+
},
360+
},
361+
},
317362
{
318363
name: "timed out",
319364
args: args{
@@ -331,16 +376,18 @@ func TestPollToken(t *testing.T) {
331376
},
332377
},
333378
},
334-
url: "https://github.com/oauth",
335-
clientID: "CLIENT-ID",
336-
code: &CodeResponse{
337-
DeviceCode: "DEVIC",
338-
UserCode: "123-abc",
339-
VerificationURI: "http://verify.me",
340-
ExpiresIn: 99,
341-
Interval: 5,
342-
timeSleep: mockSleep,
343-
timeNow: clock("0", "5s", "15m"),
379+
url: "https://github.com/oauth",
380+
opts: PollOptions{
381+
ClientID: "CLIENT-ID",
382+
DeviceCode: &CodeResponse{
383+
DeviceCode: "DEVIC",
384+
UserCode: "123-abc",
385+
VerificationURI: "http://verify.me",
386+
ExpiresIn: 99,
387+
Interval: 5,
388+
},
389+
timeSleep: mockSleep,
390+
timeNow: clock("0", "5s", "15m"),
344391
},
345392
},
346393
wantErr: "authentication timed out",
@@ -376,16 +423,18 @@ func TestPollToken(t *testing.T) {
376423
},
377424
},
378425
},
379-
url: "https://github.com/oauth",
380-
clientID: "CLIENT-ID",
381-
code: &CodeResponse{
382-
DeviceCode: "DEVIC",
383-
UserCode: "123-abc",
384-
VerificationURI: "http://verify.me",
385-
ExpiresIn: 99,
386-
Interval: 5,
387-
timeSleep: mockSleep,
388-
timeNow: clock("0", "5s"),
426+
url: "https://github.com/oauth",
427+
opts: PollOptions{
428+
ClientID: "CLIENT-ID",
429+
DeviceCode: &CodeResponse{
430+
DeviceCode: "DEVIC",
431+
UserCode: "123-abc",
432+
VerificationURI: "http://verify.me",
433+
ExpiresIn: 99,
434+
Interval: 5,
435+
},
436+
timeSleep: mockSleep,
437+
timeNow: clock("0", "5s"),
389438
},
390439
},
391440
wantErr: "access_denied",
@@ -405,7 +454,7 @@ func TestPollToken(t *testing.T) {
405454
for _, tt := range tests {
406455
t.Run(tt.name, func(t *testing.T) {
407456
totalSlept = 0
408-
got, err := PollToken(&tt.args.http, tt.args.url, tt.args.clientID, tt.args.code)
457+
got, err := PollTokenWithOptions(&tt.args.http, tt.args.url, tt.args.opts)
409458
if (err != nil) != (tt.wantErr != "") {
410459
t.Errorf("PollToken() error = %v, wantErr %v", err, tt.wantErr)
411460
return

device/examples_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ func Example() {
2222
fmt.Printf("Copy code: %s\n", code.UserCode)
2323
fmt.Printf("then open: %s\n", code.VerificationURI)
2424

25-
accessToken, err := PollToken(httpClient, "https://github.com/login/oauth/access_token", clientID, code)
25+
accessToken, err := PollTokenWithOptions(httpClient, "https://github.com/login/oauth/access_token", PollOptions{
26+
ClientID: clientID,
27+
DeviceCode: code,
28+
})
2629
if err != nil {
2730
panic(err)
2831
}

oauth_device.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) {
5858
return nil, fmt.Errorf("error opening the web browser: %w", err)
5959
}
6060

61-
return device.PollToken(httpClient, host.TokenURL, oa.ClientID, code)
61+
return device.PollTokenWithOptions(httpClient, host.TokenURL, device.PollOptions{
62+
ClientID: oa.ClientID,
63+
DeviceCode: code,
64+
})
6265
}
6366

6467
func waitForEnter(r io.Reader) error {

0 commit comments

Comments
 (0)