Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["@deyandyankov, @pebeto, and contributors"]
version = "0.6.0"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Expand All @@ -12,6 +13,7 @@ URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
Base64 = "1.0"
HTTP = "1.0"
JSON = "0.21"
ShowCases = "0.1"
Expand Down
1 change: 1 addition & 0 deletions src/MLFlowClient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module MLFlowClient
using Dates
using UUIDs
using HTTP
using Base64
using URIs
using JSON
using ShowCases
Expand Down
57 changes: 50 additions & 7 deletions src/types/mlflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,19 @@ Base type which defines location and version for MLFlow API service.
# Fields
- `apiroot::String`: API root URL, e.g. `http://localhost:5000/api`
- `apiversion::Union{Integer, AbstractFloat}`: used API version, e.g. `2.0`
- `headers::Dict`: HTTP headers to be provided with the REST API requests (useful for
authetication tokens) Default is `false`, using the REST API endpoint.
- `headers::Dict`: HTTP headers to be provided with the REST API requests.
- `username::Union{Nothing, String}`: username for basic authentication.
- `password::Union{Nothing, String}`: password for basic authentication.

!!! warning
You cannot provide an `Authorization` header when an `username` and `password` are
provided. An error will be thrown in that case.

!!! note
- If `MLFLOW_TRACKING_URI` is set, the provided `apiroot` will be ignored.
- If `MLFLOW_TRACKING_USERNAME` is set, the provided `username` will be ignored.
- If `MLFLOW_TRACKING_PASSWORD` is set, the provided `password` will be ignored.
These indications will be displayed as warnings.

# Examples

Expand All @@ -19,17 +30,49 @@ mlf = MLFlow()
remote_url="https://<your-server>.cloud.databricks.com"; # address of your remote server
mlf = MLFlow(remote_url, headers=Dict("Authorization" => "Bearer <your-secret-token>"))
```

"""
struct MLFlow
apiroot::String
apiversion::AbstractFloat
headers::Dict
username::Union{Nothing,String}
password::Union{Nothing,String}

function MLFlow(apiroot, apiversion, headers, username, password)
if haskey(ENV, "MLFLOW_TRACKING_URI")
@warn "The provided apiroot will be ignored as MLFLOW_TRACKING_URI is set."
apiroot = ENV["MLFLOW_TRACKING_URI"]
end

if haskey(ENV, "MLFLOW_TRACKING_USERNAME")
@warn "The provided username will be ignored as MLFLOW_TRACKING_USERNAME is set."
username = ENV["MLFLOW_TRACKING_USERNAME"]
end

if haskey(ENV, "MLFLOW_TRACKING_PASSWORD")
@warn "The provided password will be ignored as MLFLOW_TRACKING_PASSWORD is set."
password = ENV["MLFLOW_TRACKING_PASSWORD"]
end

if username |> !isnothing && password |> !isnothing
if haskey(headers, "Authorization")
error("You cannot provide an Authorization header when an username and password are provided.")
end
encoded_credentials = Base64.base64encode("$(username):$(password)")
headers =
merge(headers, Dict("Authorization" => "Basic $(encoded_credentials)"))
end
new(apiroot, apiversion, headers, username, password)
end
end
MLFlow(apiroot; apiversion=2.0, headers=Dict()) = MLFlow(apiroot, apiversion, headers)
MLFlow(; apiroot="http://localhost:5000/api", apiversion=2.0, headers=Dict()) =
MLFlow((haskey(ENV, "MLFLOW_TRACKING_URI") ?
ENV["MLFLOW_TRACKING_URI"] : apiroot), apiversion, headers)
MLFlow(apiroot::String; apiversion::AbstractFloat=2.0, headers::Dict=Dict(),
username::Union{Nothing,String}=nothing,
password::Union{Nothing,String}=nothing)::MLFlow =
MLFlow(apiroot, apiversion, headers, username, password)
MLFlow(; apiroot::String="http://localhost:5000/api", apiversion::AbstractFloat=2.0,
headers::Dict=Dict(), username::Union{Nothing,String}=nothing,
password::Union{Nothing,String}=nothing)::MLFlow =
MLFlow(apiroot, apiversion, headers, username, password)

Base.show(io::IO, t::MLFlow) =
show(io, ShowCase(t, [:apiroot, :apiversion], new_lines=true))
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ end

include("base.jl")

include("types/mlflow.jl")

include("services/run.jl")
include("services/misc.jl")
include("services/logger.jl")
Expand Down
9 changes: 8 additions & 1 deletion test/services/user.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@ end
updateuserpassword(getmlfinstance(encoded_credentials), "missy", "ana")
encoded_credentials = Base64.base64encode("$(user.username):ana")

@test_nowarn searchexperiments(getmlfinstance(encoded_credentials))
@test begin
try
searchexperiments(getmlfinstance(encoded_credentials))
true
catch
false
end
end
deleteuser(mlf, user.username)
end

Expand Down
74 changes: 74 additions & 0 deletions test/types/mlflow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
@testset verbose = true "instantiate mlflow" begin
mlflow_tracking_uri = ENV["MLFLOW_TRACKING_URI"]

@testset "using default constructor" begin
delete!(ENV, "MLFLOW_TRACKING_URI")

instance = MLFlow("test", 2.0, Dict(), nothing, nothing)

@test instance.apiroot == "test"
@test instance.apiversion == 2.0
@test instance.headers == Dict()
@test isnothing(instance.username)
@test isnothing(instance.password)

ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
end

@testset "using apiroot-only constructor" begin
delete!(ENV, "MLFLOW_TRACKING_URI")

instance = MLFlow("test")

@test instance.apiroot == "test"
@test instance.apiversion == 2.0
@test instance.headers == Dict()
@test isnothing(instance.username)
@test isnothing(instance.password)

ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
end

@testset "using constructor with keyword arguments" begin
delete!(ENV, "MLFLOW_TRACKING_URI")

instance = MLFlow(; username="test", password="test")

@test instance.apiroot == "http://localhost:5000/api"
@test instance.apiversion == 2.0
@test haskey(instance.headers, "Authorization")
@test instance.username == "test"
@test instance.password == "test"

ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
end

@testset "using env variables" begin
mlflow_tracking_username =
haskey(ENV, "MLFLOW_TRACKING_USERNAME") ? ENV["MLFLOW_TRACKING_USERNAME"] : nothing
mlflow_tracking_password =
haskey(ENV, "MLFLOW_TRACKING_PASSWORD") ? ENV["MLFLOW_TRACKING_PASSWORD"] : nothing

ENV["MLFLOW_TRACKING_USERNAME"] = "test"
ENV["MLFLOW_TRACKING_PASSWORD"] = "test"

@test_logs (:warn, "The provided apiroot will be ignored as MLFLOW_TRACKING_URI is set.") (:warn, "The provided username will be ignored as MLFLOW_TRACKING_USERNAME is set.") (:warn, "The provided password will be ignored as MLFLOW_TRACKING_PASSWORD is set.") MLFlow()

if !isnothing(mlflow_tracking_username)
ENV["MLFLOW_TRACKING_USERNAME"] = mlflow_tracking_username
else
delete!(ENV, "MLFLOW_TRACKING_USERNAME")
end
if !isnothing(mlflow_tracking_password)
ENV["MLFLOW_TRACKING_PASSWORD"] = mlflow_tracking_password
else
delete!(ENV, "MLFLOW_TRACKING_PASSWORD")
end
end

@testset "defining username, password and authorization header" begin
encoded_credentials = Base64.base64encode("test:test")
@test_throws ErrorException MLFlow(; username="test", password="test",
headers=Dict("Authorization" => "Basic $encoded_credentials"))
end
end
Loading