Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["@deyandyankov, @pebeto, and contributors"]
version = "0.6.0"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Expand All @@ -12,6 +13,7 @@ URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
Base64 = "1.11.0"
HTTP = "1.0"
JSON = "0.21"
ShowCases = "0.1"
Expand Down
1 change: 1 addition & 0 deletions src/MLFlowClient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module MLFlowClient
using Dates
using UUIDs
using HTTP
using Base64
using URIs
using JSON
using ShowCases
Expand Down
57 changes: 50 additions & 7 deletions src/types/mlflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,19 @@ Base type which defines location and version for MLFlow API service.
# Fields
- `apiroot::String`: API root URL, e.g. `http://localhost:5000/api`
- `apiversion::Union{Integer, AbstractFloat}`: used API version, e.g. `2.0`
- `headers::Dict`: HTTP headers to be provided with the REST API requests (useful for
authetication tokens) Default is `false`, using the REST API endpoint.
- `headers::Dict`: HTTP headers to be provided with the REST API requests.
- `username::Union{Nothing, String}`: username for basic authentication.
- `password::Union{Nothing, String}`: password for basic authentication.

!!! warning
You cannot provide an `Authorization` header when an `username` and `password` are
provided. An error will be thrown in that case.

!!! note
- If `MLFLOW_TRACKING_URI` is set, the provided `apiroot` will be ignored.
- If `MLFLOW_TRACKING_USERNAME` is set, the provided `username` will be ignored.
- If `MLFLOW_TRACKING_PASSWORD` is set, the provided `password` will be ignored.
These indications will be displayed as warnings.

# Examples

Expand All @@ -19,17 +30,49 @@ mlf = MLFlow()
remote_url="https://<your-server>.cloud.databricks.com"; # address of your remote server
mlf = MLFlow(remote_url, headers=Dict("Authorization" => "Bearer <your-secret-token>"))
```

"""
struct MLFlow
apiroot::String
apiversion::AbstractFloat
headers::Dict
username::Union{Nothing,String}
password::Union{Nothing,String}

function MLFlow(apiroot, apiversion, headers, username, password)
if haskey(ENV, "MLFLOW_TRACKING_URI")
@warn "The provided apiroot will be ignored as MLFLOW_TRACKING_URI is set."
apiroot = ENV["MLFLOW_TRACKING_URI"]
end

if haskey(ENV, "MLFLOW_TRACKING_USERNAME")
@warn "The provided username will be ignored as MLFLOW_TRACKING_USERNAME is set."
username = ENV["MLFLOW_TRACKING_USERNAME"]
end

if haskey(ENV, "MLFLOW_TRACKING_PASSWORD")
@warn "The provided password will be ignored as MLFLOW_TRACKING_PASSWORD is set."
password = ENV["MLFLOW_TRACKING_PASSWORD"]
end

if username |> !isnothing && password |> !isnothing
if haskey(headers, "Authorization")
error("You cannot provide an Authorization header when an username and password are provided.")
end
encoded_credentials = Base64.base64encode("$(username):$(password)")
headers =
merge(headers, Dict("Authorization" => "Basic $(encoded_credentials)"))
end
new(apiroot, apiversion, headers, username, password)
end
end
MLFlow(apiroot; apiversion=2.0, headers=Dict()) = MLFlow(apiroot, apiversion, headers)
MLFlow(; apiroot="http://localhost:5000/api", apiversion=2.0, headers=Dict()) =
MLFlow((haskey(ENV, "MLFLOW_TRACKING_URI") ?
ENV["MLFLOW_TRACKING_URI"] : apiroot), apiversion, headers)
MLFlow(apiroot::String; apiversion::AbstractFloat=2.0, headers::Dict=Dict(),
username::Union{Nothing,String}=nothing,
password::Union{Nothing,String}=nothing)::MLFlow =
MLFlow(apiroot, apiversion, headers, username, password)
MLFlow(; apiroot::String="http://localhost:5000/api", apiversion::AbstractFloat=2.0,
headers::Dict=Dict(), username::Union{Nothing,String}=nothing,
password::Union{Nothing,String}=nothing)::MLFlow =
MLFlow(apiroot, apiversion, headers, username, password)

Base.show(io::IO, t::MLFlow) =
show(io, ShowCase(t, [:apiroot, :apiversion], new_lines=true))
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ end

include("base.jl")

include("types/mlflow.jl")

include("services/run.jl")
include("services/misc.jl")
include("services/logger.jl")
Expand Down
9 changes: 8 additions & 1 deletion test/services/user.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@ end
updateuserpassword(getmlfinstance(encoded_credentials), "missy", "ana")
encoded_credentials = Base64.base64encode("$(user.username):ana")

@test_nowarn searchexperiments(getmlfinstance(encoded_credentials))
@test begin
try
searchexperiments(getmlfinstance(encoded_credentials))
true
catch
false
end
end
deleteuser(mlf, user.username)
end

Expand Down
74 changes: 74 additions & 0 deletions test/types/mlflow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
@testset verbose = true "instantiate mlflow" begin
mlflow_tracking_uri = ENV["MLFLOW_TRACKING_URI"]

@testset "using default constructor" begin
delete!(ENV, "MLFLOW_TRACKING_URI")

instance = MLFlow("test", 2.0, Dict(), nothing, nothing)

@test instance.apiroot == "test"
@test instance.apiversion == 2.0
@test instance.headers == Dict()
@test isnothing(instance.username)
@test isnothing(instance.password)

ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
end

@testset "using apiroot-only constructor" begin
delete!(ENV, "MLFLOW_TRACKING_URI")

instance = MLFlow("test")

@test instance.apiroot == "test"
@test instance.apiversion == 2.0
@test instance.headers == Dict()
@test isnothing(instance.username)
@test isnothing(instance.password)

ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
end

@testset "using constructor with keyword arguments" begin
delete!(ENV, "MLFLOW_TRACKING_URI")

instance = MLFlow(; username="test", password="test")

@test instance.apiroot == "http://localhost:5000/api"
@test instance.apiversion == 2.0
@test haskey(instance.headers, "Authorization")
@test instance.username == "test"
@test instance.password == "test"

ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
end

@testset "using env variables" begin
mlflow_tracking_username =
haskey(ENV, "MLFLOW_TRACKING_USERNAME") ? ENV["MLFLOW_TRACKING_USERNAME"] : nothing
mlflow_tracking_password =
haskey(ENV, "MLFLOW_TRACKING_PASSWORD") ? ENV["MLFLOW_TRACKING_PASSWORD"] : nothing

ENV["MLFLOW_TRACKING_USERNAME"] = "test"
ENV["MLFLOW_TRACKING_PASSWORD"] = "test"

@test_logs (:warn, "The provided apiroot will be ignored as MLFLOW_TRACKING_URI is set.") (:warn, "The provided username will be ignored as MLFLOW_TRACKING_USERNAME is set.") (:warn, "The provided password will be ignored as MLFLOW_TRACKING_PASSWORD is set.") MLFlow()

if !isnothing(mlflow_tracking_username)
ENV["MLFLOW_TRACKING_USERNAME"] = mlflow_tracking_username
else
delete!(ENV, "MLFLOW_TRACKING_USERNAME")
end
if !isnothing(mlflow_tracking_password)
ENV["MLFLOW_TRACKING_PASSWORD"] = mlflow_tracking_password
else
delete!(ENV, "MLFLOW_TRACKING_PASSWORD")
end
end

@testset "defining username, password and authorization header" begin
encoded_credentials = Base64.base64encode("test:test")
@test_throws ErrorException MLFlow(; username="test", password="test",
headers=Dict("Authorization" => "Basic $encoded_credentials"))
end
end
Loading