Skip to content

Commit 2ac6325

Browse files
committed
Use custom config endpoint
1 parent 192157a commit 2ac6325

File tree

3 files changed

+92
-53
lines changed

3 files changed

+92
-53
lines changed

src/PkgAuthentication.jl

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -183,48 +183,66 @@ function should_use_device_auth()
183183
return !isempty(get_device_auth_client_id())
184184
end
185185

186-
function get_openid_configuration(state::NoAuthentication)
186+
# Query the /auth/configuration endpoint to get the refresh url and
187+
# device authentication endpoints. Returns a Dict with the following
188+
# fields:
189+
# - `device_flow_supported`::Bool: Indicates whether device flow is
190+
# enabled on the server.
191+
# - `refresh_url`::String: The refresh URL for refreshing the auth
192+
# token
193+
# - `device_authorization_endpoint`::String: The endpoint that must
194+
# be called to initiate device flow authentication. This field is
195+
# only present when device flow is enabled on the server.
196+
# - `token_endpoint`::String: The endpoint that should be called to
197+
# retrieve the authentication token after the user has approved
198+
# the authorization request. This field is only present when device
199+
# flow is enabled on the server.
200+
function get_auth_configuration(state::NoAuthentication)
187201
output = IOBuffer()
202+
auth_suffix = isempty(state.auth_suffix) ? "auth" : state.auth_suffix
188203
response = Downloads.request(
189-
joinpath(state.server, ".well-known/openid-configuration"),
204+
joinpath(state.server, auth_suffix, "configuration"),
190205
method = "GET",
191206
output = output,
192207
throw = false,
193208
headers = ["Accept" => "application/json"],
194209
)
195210

211+
def_resp = Dict{String, Any}(
212+
"device_flow_supported" => false,
213+
"refresh_url" => joinpath(state.server, auth_suffix, "renew/token.toml/v2/")
214+
)
215+
196216
if response isa Downloads.Response && response.status == 200
197217
body = nothing
198218
content = String(take!(output))
199219
try
200220
body = JSON.parse(content)
201221
catch ex
202222
@debug "Request for well known configuration returned: ", content
203-
return false, "", ""
223+
return def_resp
204224
end
205225

206226
if body !== nothing
207-
return true, body["device_authorization_endpoint"], body["token_endpoint"]
227+
@assert haskey(body, "device_flow_supported")
228+
@assert haskey(body, "refresh_url")
229+
@assert (body["device_flow_supported"] && haskey(body, "device_authorization_endpoint") && haskey(body, "token_endpoint")) || !body["device_flow_supported"]
230+
return body
208231
end
209232
end
210233

211-
return false, "", ""
234+
return def_resp
212235
end
213236

214237
function step(state::NoAuthentication)::Union{RequestLogin, Failure}
215-
token_endpoint = ""
216-
device_endpoint = ""
217-
if should_use_device_auth()
218-
s, device_endpoint, token_endpoint = get_openid_configuration(state)
219-
s || GenericError("Unable to get device and token endpoints")
220-
end
221-
success, challenge, body_or_response = if should_use_device_auth()
222-
fetch_device_code(state, device_endpoint)
238+
auth_config = get_auth_configuration(state)
239+
success, challenge, body_or_response = if auth_config["device_flow_supported"]
240+
fetch_device_code(state, auth_config["device_authorization_endpoint"])
223241
else
224242
initiate_browser_challenge(state)
225243
end
226244
if success
227-
return RequestLogin(state.server, state.auth_suffix, challenge, body_or_response, token_endpoint)
245+
return RequestLogin(state.server, state.auth_suffix, challenge, body_or_response, get(auth_config, "token_endpoint", ""), auth_config["refresh_url"])
228246
else
229247
return HttpError(body_or_response)
230248
end
@@ -251,7 +269,6 @@ function fetch_device_code(state::NoAuthentication, device_endpoint::AbstractStr
251269
end
252270

253271
if body !== nothing
254-
body["client"] = "device"
255272
return true, "", body
256273
end
257274
end
@@ -314,7 +331,6 @@ Base.show(io::IO, s::NeedRefresh) = print(io, "NeedRefresh($(s.server), $(s.auth
314331
function step(state::NeedRefresh)::Union{HasNewToken, NoAuthentication}
315332
refresh_token = state.token["refresh_token"]
316333
output = IOBuffer()
317-
is_device = get(state.token, "client", nothing) == "device"
318334
response = Downloads.request(
319335
state.token["refresh_url"],
320336
method = "GET",
@@ -331,17 +347,15 @@ function step(state::NeedRefresh)::Union{HasNewToken, NoAuthentication}
331347
assert_dict_keys(body, "expires_in"; msg=msg)
332348
assert_dict_keys(body, "expires", "expires_at"; msg=msg)
333349
end
334-
if is_device
335-
body["client"] = "device"
336-
# refresh_url and expires/expires_at will be present in this refreshed token
337-
# so no need to manually add them here
338-
end
350+
@info("Successfully refreshed token")
339351
return HasNewToken(state.server, body)
340352
catch err
341353
@debug "invalid body received while refreshing token" exception=(err, catch_backtrace())
342354
end
355+
@info "Did not refresh token, could not json parse ", response
343356
return NoAuthentication(state.server, state.auth_suffix)
344357
else
358+
@info "Did not refresh token, got non 200 response ", response
345359
@debug "request for refreshing token failed" response
346360
return NoAuthentication(state.server, state.auth_suffix)
347361
end
@@ -404,11 +418,12 @@ struct RequestLogin <: State
404418
challenge::String
405419
response::Union{String, Dict{String, Any}}
406420
token_endpoint::String
421+
refresh_url::String
407422
end
408-
Base.show(io::IO, s::RequestLogin) = print(io, "RequestLogin($(s.server), $(s.auth_suffix), <REDACTED>, $(s.response), $(s.token_endpoint))")
423+
Base.show(io::IO, s::RequestLogin) = print(io, "RequestLogin($(s.server), $(s.auth_suffix), <REDACTED>, $(s.response), $(s.token_endpoint), $(s.refresh_url))")
409424

410425
function step(state::RequestLogin)::Union{ClaimToken, Failure}
411-
is_device = state.response isa Dict{String, Any} && get(state.response, "client", nothing) == "device"
426+
is_device = !isempty(state.token_endpoint)
412427
url = if is_device
413428
string(state.response["verification_uri_complete"])
414429
else
@@ -418,9 +433,9 @@ function step(state::RequestLogin)::Union{ClaimToken, Failure}
418433
success = open_browser(url)
419434
if success && is_device
420435
# In case of device tokens, timeout for challenge is received in the initial request.
421-
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, Inf, time(), state.response["expires_in"], 2, 0, 10, state.token_endpoint)
436+
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, Inf, time(), state.response["expires_in"], 2, 0, 10, state.token_endpoint, state.refresh_url)
422437
elseif success
423-
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, state.token_endpoint)
438+
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, state.token_endpoint, state.refresh_url)
424439
else # this can only happen for the browser hook
425440
return GenericError("Failed to execute open_browser hook.")
426441
end
@@ -443,11 +458,12 @@ struct ClaimToken <: State
443458
failures::Int
444459
max_failures::Int
445460
token_endpoint::String
461+
refresh_url::String
446462
end
447-
Base.show(io::IO, s::ClaimToken) = print(io, "ClaimToken($(s.server), $(s.auth_suffix), <REDACTED>, $(s.response), $(s.expiry), $(s.start_time), $(s.timeout), $(s.poll_interval), $(s.failures), $(s.max_failures), $(s.token_endpoint))")
463+
Base.show(io::IO, s::ClaimToken) = print(io, "ClaimToken($(s.server), $(s.auth_suffix), <REDACTED>, $(s.response), $(s.expiry), $(s.start_time), $(s.timeout), $(s.poll_interval), $(s.failures), $(s.max_failures), $(s.token_endpoint), $(s.refresh_url))")
448464

449-
ClaimToken(server, auth_suffix, challenge, response, token_endpoint, expiry = Inf, failures = 0) =
450-
ClaimToken(server, auth_suffix, challenge, response, expiry, time(), 180, 2, failures, 10, token_endpoint)
465+
ClaimToken(server, auth_suffix, challenge, response, token_endpoint, refresh_url, expiry = Inf, failures = 0) =
466+
ClaimToken(server, auth_suffix, challenge, response, expiry, time(), 180, 2, failures, 10, token_endpoint, refresh_url)
451467

452468
function step(state::ClaimToken)::Union{ClaimToken, HasNewToken, Failure}
453469
if time() > state.expiry || (time() - state.start_time)/1e6 > state.timeout # server-side or client-side timeout
@@ -461,7 +477,7 @@ function step(state::ClaimToken)::Union{ClaimToken, HasNewToken, Failure}
461477
sleep(state.poll_interval)
462478

463479
output = IOBuffer()
464-
is_device = state.response isa Dict{String, Any} && get(state.response, "client", nothing) == "device"
480+
is_device = !isempty(state.token_endpoint)
465481
if is_device
466482
output = IOBuffer()
467483
response = Downloads.request(
@@ -490,25 +506,25 @@ function step(state::ClaimToken)::Union{ClaimToken, HasNewToken, Failure}
490506
body = try
491507
JSON.parse(String(take!(output)))
492508
catch err
493-
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, state.expiry, state.start_time, state.timeout, state.poll_interval, state.failures + 1, state.max_failures, state.token_endpoint)
509+
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, state.expiry, state.start_time, state.timeout, state.poll_interval, state.failures + 1, state.max_failures, state.token_endpoint, state.refresh_url)
494510
end
495511

496512
if haskey(body, "token")
497513
return HasNewToken(state.server, body["token"])
498514
elseif haskey(body, "expiry") # time at which the response/challenge pair will expire on the server
499-
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, body["expiry"], state.start_time, state.timeout, state.poll_interval, state.failures, state.max_failures, state.token_endpoint)
515+
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, body["expiry"], state.start_time, state.timeout, state.poll_interval, state.failures, state.max_failures, state.token_endpoint, state.refresh_url)
500516
else
501-
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, state.expiry, state.start_time, state.timeout, state.poll_interval, state.failures + 1, state.max_failures, state.token_endpoint)
517+
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, state.expiry, state.start_time, state.timeout, state.poll_interval, state.failures + 1, state.max_failures, state.token_endpoint, state.refresh_url)
502518
end
503519
elseif response isa Downloads.Response && response.status == 200
504520
body = JSON.parse(String(take!(output)))
505-
body["client"] = "device"
506521
body["expires"] = body["expires_in"] + Int(floor(time()))
507522
body["expires_at"] = body["expires"]
508-
body["refresh_url"] = joinpath(state.server, "auth/renew/token.toml/device/") # Need to be careful with auth suffix, if set
523+
@info("Setting refresh url to ", state.refresh_url)
524+
body["refresh_url"] = state.refresh_url
509525
return HasNewToken(state.server, body)
510526
elseif response isa Downloads.Response && response.status in [401, 400] && is_device
511-
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, state.expiry, state.start_time, state.timeout, state.poll_interval, state.failures + 1, state.max_failures, state.token_endpoint)
527+
return ClaimToken(state.server, state.auth_suffix, state.challenge, state.response, state.expiry, state.start_time, state.timeout, state.poll_interval, state.failures + 1, state.max_failures, state.token_endpoint, state.refresh_url)
512528
else
513529
return HttpError(response)
514530
end

test/authserver.jl

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ import TOML
44
const EXPIRY = 30
55
const CHALLENGE_EXPIRY = 10
66
const PORT = 8888
7+
const LEGACY_MODE = 1
8+
const DEVICE_FLOW_MODE = 2
79

810
const ID_TOKEN = Random.randstring(100)
911
const TOKEN = Ref(Dict())
12+
const MODE = Ref(LEGACY_MODE)
1013

1114
challenge_response_map = Dict()
1215
challenge_timeout = Dict()
@@ -99,18 +102,36 @@ function check_validity(req)
99102
return HTTP.Response(200, payload == TOKEN[])
100103
end
101104

105+
function set_mode_legacy(req)
106+
MODE[] = LEGACY_MODE
107+
return HTTP.Response(200)
108+
end
102109

110+
function set_mode_device(req)
111+
MODE[] = DEVICE_FLOW_MODE
112+
return HTTP.Response(200)
113+
end
103114

104-
# --------- Device auth methods -----------------
105-
106-
function openid_configuration(req)
107-
return HTTP.Response(
108-
200,
109-
""" {
110-
"device_authorization_endpoint": "http://localhost:8888/auth/device/code",
111-
"token_endpoint": "http://localhost:8888/auth/token"
112-
} """,
113-
)
115+
function auth_configuration(req)
116+
if MODE[] == LEGACY_MODE
117+
return HTTP.Response(
118+
200,
119+
""" {
120+
"device_flow_supported": false,
121+
"refresh_url": "http://localhost:$PORT/auth/renew/token.toml/v2/"
122+
} """,
123+
)
124+
else
125+
return HTTP.Response(
126+
200,
127+
""" {
128+
"device_flow_supported": true,
129+
"refresh_url": "http://localhost:$PORT/auth/renew/token.toml/device/",
130+
"device_authorization_endpoint": "http://localhost:$PORT/auth/device/code",
131+
"token_endpoint": "http://localhost:$PORT/auth/token"
132+
} """,
133+
)
134+
end
114135
end
115136

116137
device_code_user_code_map = Dict{String, Any}()
@@ -126,7 +147,7 @@ function auth_device_code(req)
126147
""" {
127148
"device_code": "$device_code",
128149
"user_code": "$user_code",
129-
"verification_uri_complete": "http://localhost:8888/auth/device?user_code=$user_code",
150+
"verification_uri_complete": "http://localhost:$PORT/auth/device?user_code=$user_code",
130151
"expires_in": $CHALLENGE_EXPIRY
131152
} """,
132153
)
@@ -141,11 +162,11 @@ function auth_device(req)
141162
end
142163
authenticated[device_code] = true
143164
refresh_token = Random.randstring(10)
144-
TOKEN[]["access_token"] = "full-$ID_TOKEN"
165+
TOKEN[]["access_token"] = "device-$ID_TOKEN"
145166
TOKEN[]["token_type"] = "bearer"
146167
TOKEN[]["expires_in"] = EXPIRY
147168
TOKEN[]["refresh_token"] = refresh_token
148-
TOKEN[]["id_token"] = "full-$ID_TOKEN"
169+
TOKEN[]["id_token"] = "device-$ID_TOKEN"
149170
return HTTP.Response(200)
150171
end
151172

@@ -169,11 +190,13 @@ HTTP.register!(router, "GET", "/auth/response", response_handler)
169190
HTTP.register!(router, "POST", "/auth/claimtoken", claimtoken_handler)
170191
HTTP.register!(router, "GET", "/auth/renew/token.toml/v2", renew_handler)
171192
HTTP.register!(router, "POST", "/auth/isvalid", check_validity)
172-
HTTP.register!(router, "GET", "/.well-known/openid-configuration", openid_configuration)
193+
HTTP.register!(router, "GET", "/auth/configuration", auth_configuration)
173194
HTTP.register!(router, "POST", "/auth/device/code", auth_device_code)
174195
HTTP.register!(router, "GET", "/auth/device", auth_device)
175196
HTTP.register!(router, "POST", "/auth/token", auth_token)
176197
HTTP.register!(router, "GET", "/auth/renew/token.toml/device", renew_handler)
198+
HTTP.register!(router, "POST", "/set_mode/legacy", set_mode_legacy)
199+
HTTP.register!(router, "POST", "/set_mode/device", set_mode_device)
177200

178201
function run()
179202
println("starting server")

test/tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,16 @@ PkgAuthentication.register_open_browser_hook(url -> HTTP.get(url))
6161
@test startswith(success2.token["id_token"], "refresh-")
6262
end
6363

64-
ENV["JULIA_PKG_AUTHENTICATION_DEVICE_CLIENT_ID"] = "device"
65-
6664
@testset "auth with running server (device flow)" begin
6765
delete_token()
66+
HTTP.post(joinpath(test_pkg_server, "set_mode/device"))
6867

6968
@info "testing inital auth"
7069
success = PkgAuthentication.authenticate(test_pkg_server)
7170

7271
@test success isa PkgAuthentication.Success
7372
@test success.token["expires_at"] > time()
74-
@test startswith(success.token["id_token"], "full-")
73+
@test startswith(success.token["id_token"], "device-")
7574
@test !occursin("id_token", sprint(show, success))
7675

7776
sleeptimer = ceil(Int, success.token["expires_at"] - time() + 1)
@@ -85,9 +84,10 @@ ENV["JULIA_PKG_AUTHENTICATION_DEVICE_CLIENT_ID"] = "device"
8584
@test success2.token["expires_at"] > time()
8685
@test success2.token["refresh_token"] !== success.token["refresh_token"]
8786
@test startswith(success2.token["id_token"], "refresh-")
87+
88+
HTTP.post(joinpath(test_pkg_server, "set_mode/legacy"))
8889
end
8990

90-
ENV["JULIA_PKG_AUTHENTICATION_DEVICE_CLIENT_ID"] = ""
9191

9292
@testset "PkgAuthentication.install" begin
9393
delete_token()

0 commit comments

Comments
 (0)