diff --git a/Project.toml b/Project.toml index 08365ae..a27b71b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["@deyandyankov, @pebeto, and contributors"] version = "0.6.0" [deps] +Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" @@ -12,6 +13,7 @@ URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] +Base64 = "1.0" HTTP = "1.0" JSON = "0.21" ShowCases = "0.1" diff --git a/src/MLFlowClient.jl b/src/MLFlowClient.jl index 2a46265..92a41c8 100644 --- a/src/MLFlowClient.jl +++ b/src/MLFlowClient.jl @@ -14,6 +14,7 @@ module MLFlowClient using Dates using UUIDs using HTTP +using Base64 using URIs using JSON using ShowCases diff --git a/src/types/mlflow.jl b/src/types/mlflow.jl index 111232e..9848047 100644 --- a/src/types/mlflow.jl +++ b/src/types/mlflow.jl @@ -6,8 +6,19 @@ Base type which defines location and version for MLFlow API service. # Fields - `apiroot::String`: API root URL, e.g. `http://localhost:5000/api` - `apiversion::Union{Integer, AbstractFloat}`: used API version, e.g. `2.0` -- `headers::Dict`: HTTP headers to be provided with the REST API requests (useful for - authetication tokens) Default is `false`, using the REST API endpoint. +- `headers::Dict`: HTTP headers to be provided with the REST API requests. +- `username::Union{Nothing, String}`: username for basic authentication. +- `password::Union{Nothing, String}`: password for basic authentication. + +!!! warning + You cannot provide an `Authorization` header when an `username` and `password` are + provided. An error will be thrown in that case. + +!!! note + - If `MLFLOW_TRACKING_URI` is set, the provided `apiroot` will be ignored. + - If `MLFLOW_TRACKING_USERNAME` is set, the provided `username` will be ignored. + - If `MLFLOW_TRACKING_PASSWORD` is set, the provided `password` will be ignored. + These indications will be displayed as warnings. # Examples @@ -19,17 +30,49 @@ mlf = MLFlow() remote_url="https://.cloud.databricks.com"; # address of your remote server mlf = MLFlow(remote_url, headers=Dict("Authorization" => "Bearer ")) ``` - """ struct MLFlow apiroot::String apiversion::AbstractFloat headers::Dict + username::Union{Nothing,String} + password::Union{Nothing,String} + + function MLFlow(apiroot, apiversion, headers, username, password) + if haskey(ENV, "MLFLOW_TRACKING_URI") + @warn "The provided apiroot will be ignored as MLFLOW_TRACKING_URI is set." + apiroot = ENV["MLFLOW_TRACKING_URI"] + end + + if haskey(ENV, "MLFLOW_TRACKING_USERNAME") + @warn "The provided username will be ignored as MLFLOW_TRACKING_USERNAME is set." + username = ENV["MLFLOW_TRACKING_USERNAME"] + end + + if haskey(ENV, "MLFLOW_TRACKING_PASSWORD") + @warn "The provided password will be ignored as MLFLOW_TRACKING_PASSWORD is set." + password = ENV["MLFLOW_TRACKING_PASSWORD"] + end + + if username |> !isnothing && password |> !isnothing + if haskey(headers, "Authorization") + error("You cannot provide an Authorization header when an username and password are provided.") + end + encoded_credentials = Base64.base64encode("$(username):$(password)") + headers = + merge(headers, Dict("Authorization" => "Basic $(encoded_credentials)")) + end + new(apiroot, apiversion, headers, username, password) + end end -MLFlow(apiroot; apiversion=2.0, headers=Dict()) = MLFlow(apiroot, apiversion, headers) -MLFlow(; apiroot="http://localhost:5000/api", apiversion=2.0, headers=Dict()) = - MLFlow((haskey(ENV, "MLFLOW_TRACKING_URI") ? - ENV["MLFLOW_TRACKING_URI"] : apiroot), apiversion, headers) +MLFlow(apiroot::String; apiversion::AbstractFloat=2.0, headers::Dict=Dict(), + username::Union{Nothing,String}=nothing, + password::Union{Nothing,String}=nothing)::MLFlow = + MLFlow(apiroot, apiversion, headers, username, password) +MLFlow(; apiroot::String="http://localhost:5000/api", apiversion::AbstractFloat=2.0, + headers::Dict=Dict(), username::Union{Nothing,String}=nothing, + password::Union{Nothing,String}=nothing)::MLFlow = + MLFlow(apiroot, apiversion, headers, username, password) Base.show(io::IO, t::MLFlow) = show(io, ShowCase(t, [:apiroot, :apiversion], new_lines=true)) diff --git a/test/runtests.jl b/test/runtests.jl index 6ccb44f..1a0d314 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,8 @@ end include("base.jl") +include("types/mlflow.jl") + include("services/run.jl") include("services/misc.jl") include("services/logger.jl") diff --git a/test/services/user.jl b/test/services/user.jl index 7e81ea8..958a155 100644 --- a/test/services/user.jl +++ b/test/services/user.jl @@ -36,7 +36,14 @@ end updateuserpassword(getmlfinstance(encoded_credentials), "missy", "ana") encoded_credentials = Base64.base64encode("$(user.username):ana") - @test_nowarn searchexperiments(getmlfinstance(encoded_credentials)) + @test begin + try + searchexperiments(getmlfinstance(encoded_credentials)) + true + catch + false + end + end deleteuser(mlf, user.username) end diff --git a/test/types/mlflow.jl b/test/types/mlflow.jl new file mode 100644 index 0000000..ff4fcbd --- /dev/null +++ b/test/types/mlflow.jl @@ -0,0 +1,74 @@ +@testset verbose = true "instantiate mlflow" begin + mlflow_tracking_uri = ENV["MLFLOW_TRACKING_URI"] + + @testset "using default constructor" begin + delete!(ENV, "MLFLOW_TRACKING_URI") + + instance = MLFlow("test", 2.0, Dict(), nothing, nothing) + + @test instance.apiroot == "test" + @test instance.apiversion == 2.0 + @test instance.headers == Dict() + @test isnothing(instance.username) + @test isnothing(instance.password) + + ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri + end + + @testset "using apiroot-only constructor" begin + delete!(ENV, "MLFLOW_TRACKING_URI") + + instance = MLFlow("test") + + @test instance.apiroot == "test" + @test instance.apiversion == 2.0 + @test instance.headers == Dict() + @test isnothing(instance.username) + @test isnothing(instance.password) + + ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri + end + + @testset "using constructor with keyword arguments" begin + delete!(ENV, "MLFLOW_TRACKING_URI") + + instance = MLFlow(; username="test", password="test") + + @test instance.apiroot == "http://localhost:5000/api" + @test instance.apiversion == 2.0 + @test haskey(instance.headers, "Authorization") + @test instance.username == "test" + @test instance.password == "test" + + ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri + end + + @testset "using env variables" begin + mlflow_tracking_username = + haskey(ENV, "MLFLOW_TRACKING_USERNAME") ? ENV["MLFLOW_TRACKING_USERNAME"] : nothing + mlflow_tracking_password = + haskey(ENV, "MLFLOW_TRACKING_PASSWORD") ? ENV["MLFLOW_TRACKING_PASSWORD"] : nothing + + ENV["MLFLOW_TRACKING_USERNAME"] = "test" + ENV["MLFLOW_TRACKING_PASSWORD"] = "test" + + @test_logs (:warn, "The provided apiroot will be ignored as MLFLOW_TRACKING_URI is set.") (:warn, "The provided username will be ignored as MLFLOW_TRACKING_USERNAME is set.") (:warn, "The provided password will be ignored as MLFLOW_TRACKING_PASSWORD is set.") MLFlow() + + if !isnothing(mlflow_tracking_username) + ENV["MLFLOW_TRACKING_USERNAME"] = mlflow_tracking_username + else + delete!(ENV, "MLFLOW_TRACKING_USERNAME") + end + if !isnothing(mlflow_tracking_password) + ENV["MLFLOW_TRACKING_PASSWORD"] = mlflow_tracking_password + else + delete!(ENV, "MLFLOW_TRACKING_PASSWORD") + end + end + + @testset "defining username, password and authorization header" begin + encoded_credentials = Base64.base64encode("test:test") + @test_throws ErrorException MLFlow(; username="test", password="test", + headers=Dict("Authorization" => "Basic $encoded_credentials")) + end +end