Skip to content

Commit b90da1c

Browse files
committed
basic run search. More tests needed.
1 parent ea4db3e commit b90da1c

File tree

4 files changed

+160
-67
lines changed

4 files changed

+160
-67
lines changed

src/MLFlowClient.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ export
2525
MLFlowRunStatus,
2626
MLFlowRunInfo,
2727
MLFlowRunData,
28+
MLFlowRunDataMetric,
2829
MLFlowRun
2930

3031
include("utils.jl")
@@ -41,7 +42,8 @@ export
4142
createrun,
4243
getrun,
4344
updaterun,
44-
deleterun
45+
deleterun,
46+
searchruns
4547

4648
include("logging.jl")
4749
export

src/runs.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,48 @@ end
9595
deleterun(mlf::MLFlow, run_info::MLFlowRunInfo) = deleterun(mlf, run_info.run_id)
9696
deleterun(mlf::MLFlow, run::MLFlowRun) = deleterun(mlf, run.info)
9797

98+
"""
99+
searchruns(mlf::MLFlow, experiment_ids, filter)
100+
101+
Searches for runs in an experiment based on filter.
102+
103+
# Arguments
104+
- `mlf`: [`MLFlow`](@ref) configuration.
105+
- `experiment_ids::AbstractVector{Integer}`: `experiment_id`s in which to search for runs.
106+
107+
# Keywords
108+
- `filter::String`: filter as defined in [MLFlow documentation](https://mlflow.org/docs/latest/rest-api.html#search-runs)
109+
- `run_view_type::String`: ...
110+
- `max_results::Integer`: ...
111+
- `order_by::String`: ...
112+
113+
# Returns
114+
- a vector of runs that were found
115+
116+
"""
117+
function searchruns(mlf::MLFlow, experiment_ids::AbstractVector{<:Integer};
118+
filter::String="",
119+
run_view_type::String="ACTIVE_ONLY",
120+
max_results::Int64=50000,
121+
order_by::AbstractVector{<:String}=[""]
122+
)
123+
endpoint = "runs/search"
124+
run_view_type ["ACTIVE_ONLY", "DELETED_ONLY", "ALL"] || error("Unsupported run_view_type = $run_view_type")
125+
kwargs = (
126+
experiment_ids=experiment_ids,
127+
filter=filter,
128+
run_view_type=run_view_type,
129+
max_results=max_results,
130+
)
131+
if order_by != [""]
132+
kwargs.order_by = order_by
133+
end
134+
135+
result = mlfpost(mlf, endpoint; kwargs...)
136+
haskey(result, "runs") || error("Malformed result from MLFow")
137+
138+
map(x -> MLFlowRun(x["info"], x["data"]), result["runs"])
139+
end
140+
function searchruns(mlf::MLFlow, experiment_id::Integer; kwargs...)
141+
searchruns(mlf, [experiment_id]; kwargs...)
142+
end

src/types.jl

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,26 +127,59 @@ struct MLFlowRunInfo
127127
end
128128
end
129129

130+
"""
131+
MLFlowRunDataMetric
132+
133+
Represents a metric.
134+
135+
# Fields
136+
- `key::String`: ...
137+
- `value`: ...
138+
- `step::Int64`: ...
139+
- `timestamp::Int64`: ...
140+
"""
141+
struct MLFlowRunDataMetric
142+
key::String
143+
value::Float64
144+
step::Int64
145+
timestamp::Int64
146+
function MLFlowRunDataMetric(d::Dict{String,Any})
147+
key = d["key"]
148+
value = d["value"]
149+
step = parse(Int64, d["step"])
150+
timestamp = parse(Int64, d["timestamp"])
151+
new(key, value, step, timestamp)
152+
end
153+
end
154+
155+
130156
"""
131157
MLFlowRunData
132158
133159
Represents run data.
134160
135161
# Fields
136-
- `metrics`
137-
- `params`
162+
- `metrics::Vector{MLFlowRunDataMetric}`: run metrics.
163+
- `params::Dict{String,String}`: run parameters.
138164
- `tags`
139165
140-
# TODO
141-
Incomplete functionality.
142-
143166
"""
144167
struct MLFlowRunData
145-
metrics
146-
params
168+
metrics::Vector{MLFlowRunDataMetric}
169+
params::Union{Dict{String,String},Missing}
147170
tags
148171
function MLFlowRunData(data::Dict{String,Any})
149-
new([], [], []) # TODO: add functionality
172+
metrics = haskey(data, "metrics") ? MLFlowRunDataMetric.(data["metrics"]) : MLFlowRunDataMetric[]
173+
if haskey(data, "params")
174+
params = Dict{String,String}()
175+
for p in data["params"]
176+
params[p["key"]] = p["value"]
177+
end
178+
else
179+
params = Dict{String,String}()
180+
end
181+
tags = haskey(data, "tags") ? data["tags"] : missing
182+
new(metrics, params, tags)
150183
end
151184
end
152185

@@ -158,11 +191,16 @@ Represents an MLFlow run.
158191
# Fields
159192
- `info::MLFlowRunInfo`: Run metadata.
160193
- `data::MLFlowRunData`: Run data.
194+
161195
"""
162196
struct MLFlowRun
163-
info::MLFlowRunInfo
197+
info::Union{MLFlowRunInfo,Missing}
164198
data::Union{MLFlowRunData,Missing}
165199

200+
function MLFlowRun(rundata::MLFlowRunData)
201+
info = missing
202+
new(info, rundata)
203+
end
166204
function MLFlowRun(runinfo::MLFlowRunInfo)
167205
data = missing
168206
new(runinfo, data)

test/runtests.jl

Lines changed: 65 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,63 +18,71 @@ end
1818
@test mlf.baseuri == mlflowbaseuri
1919
@test mlf.apiversion == 2.0
2020

21-
if mlflow_server_is_running(mlf)
22-
23-
exptags = [:key => "val"]
24-
expname = "expname-$(UUIDs.uuid4())"
25-
26-
@test ismissing(getexperiment(mlf, "$(UUIDs.uuid4()) - $(UUIDs.uuid4())"))
27-
28-
experiment_id = createexperiment(mlf; name=expname, tags=exptags)
29-
experiment = getexperiment(mlf, experiment_id)
30-
@test experiment.experiment_id == experiment_id
31-
experimentbyname = getexperiment(mlf, expname)
32-
@test experimentbyname.name == experiment.name
33-
34-
35-
exprun = createrun(mlf, experiment_id)
36-
@test exprun.info.experiment_id == experiment_id
37-
@test exprun.info.lifecycle_stage == "active"
38-
@test exprun.info.status == MLFlowRunStatus("RUNNING")
39-
exprunid = exprun.info.run_id
40-
41-
logparam(mlf, exprunid, "paramkey", "paramval")
42-
logparam(mlf, exprunid, Dict("k" => "v", "k1" => "v1"))
43-
logparam(mlf, exprun, Dict("test1" => "test2"))
44-
45-
logmetric(mlf, exprun, "metrickeyrun", 1.0)
46-
logmetric(mlf, exprun.info, "metrickeyrun", 2.0)
47-
logmetric(mlf, exprun.info, "metrickeyrun", [2.5, 3.5])
48-
logmetric(mlf, exprunid, "metrickey", 1.0)
49-
logmetric(mlf, exprunid, "metrickey2", [1.0, 1.5, 2.0])
50-
51-
retrieved_run = getrun(mlf, exprunid)
52-
@test exprun.info == retrieved_run.info
53-
54-
tmpfiletoupload = tempname()
55-
f = open(tmpfiletoupload, "w")
56-
write(f, "samplecontents")
57-
close(f)
58-
logartifact(mlf, retrieved_run, tmpfiletoupload)
59-
rm(tmpfiletoupload)
21+
if !mlflow_server_is_running(mlf)
22+
return nothing
23+
end
6024

61-
running_run = updaterun(mlf, exprunid, "RUNNING")
62-
@test running_run.info.experiment_id == experiment_id
63-
@test running_run.info.status == MLFlowRunStatus("RUNNING")
64-
finished_run = updaterun(mlf, exprun, MLFlowRunStatus("FINISHED"))
65-
finishedrun = getrun(mlf, finished_run.info.run_id)
25+
exptags = [:key => "val"]
26+
expname = "expname-$(UUIDs.uuid4())"
27+
28+
@test ismissing(getexperiment(mlf, "$(UUIDs.uuid4()) - $(UUIDs.uuid4())"))
29+
30+
experiment_id = createexperiment(mlf; name=expname, tags=exptags)
31+
experiment = getexperiment(mlf, experiment_id)
32+
@test experiment.experiment_id == experiment_id
33+
experimentbyname = getexperiment(mlf, expname)
34+
@test experimentbyname.name == experiment.name
35+
36+
exprun = createrun(mlf, experiment_id)
37+
@test exprun.info.experiment_id == experiment_id
38+
@test exprun.info.lifecycle_stage == "active"
39+
@test exprun.info.status == MLFlowRunStatus("RUNNING")
40+
exprunid = exprun.info.run_id
41+
42+
logparam(mlf, exprunid, "paramkey", "paramval")
43+
logparam(mlf, exprunid, Dict("k" => "v", "k1" => "v1"))
44+
logparam(mlf, exprun, Dict("test1" => "test2"))
45+
46+
logmetric(mlf, exprun, "metrickeyrun", 1.0)
47+
logmetric(mlf, exprun.info, "metrickeyrun", 2.0)
48+
logmetric(mlf, exprun.info, "metrickeyrun", [2.5, 3.5])
49+
logmetric(mlf, exprunid, "metrickey", 1.0)
50+
logmetric(mlf, exprunid, "metrickey2", [1.0, 1.5, 2.0])
51+
52+
retrieved_run = getrun(mlf, exprunid)
53+
@test exprun.info == retrieved_run.info
54+
55+
tmpfiletoupload = tempname()
56+
f = open(tmpfiletoupload, "w")
57+
write(f, "samplecontents")
58+
close(f)
59+
logartifact(mlf, retrieved_run, tmpfiletoupload)
60+
rm(tmpfiletoupload)
61+
62+
running_run = updaterun(mlf, exprunid, "RUNNING")
63+
@test running_run.info.experiment_id == experiment_id
64+
@test running_run.info.status == MLFlowRunStatus("RUNNING")
65+
finished_run = updaterun(mlf, exprun, MLFlowRunStatus("FINISHED"))
66+
finishedrun = getrun(mlf, finished_run.info.run_id)
6667

67-
# NOTE: seems like MLFlow API never returns `end_time` as documented in https://mlflow.org/docs/latest/rest-api.html#runinfo
68-
# Consider raising an issue with MLFlow itself.
69-
@test_broken !ismissing(finishedrun.info.end_time)
70-
71-
runs = searchrun(mlf, experiment_id, "params.\"paramkey\" == \"paramval\"")
72-
73-
deleterun(mlf, exprunid)
74-
75-
deleteexperiment(mlf, experiment_id)
76-
experiment = getexperiment(mlf, experiment_id)
77-
@test experiment.experiment_id == experiment_id
78-
@test experiment.lifecycle_stage == "deleted"
79-
end
68+
# NOTE: seems like MLFlow API never returns `end_time` as documented in https://mlflow.org/docs/latest/rest-api.html#runinfo
69+
# Consider raising an issue with MLFlow itself.
70+
@test_broken !ismissing(finishedrun.info.end_time)
71+
72+
exprun2 = createrun(mlf, experiment_id)
73+
exprun2id = exprun.info.run_id
74+
logparam(mlf, exprun2, "param2", "key2")
75+
logmetric(mlf, exprun2, "metric2", [1.0, 2.0])
76+
updaterun(mlf, exprun2, "FINISHED")
77+
78+
@show experiment_id
79+
runs = searchruns(mlf, experiment_id)
80+
@test length(runs) == 2
81+
# , "params.\"paramkey\" == \"paramval\"")
82+
# deleterun(mlf, exprunid)
83+
84+
# deleteexperiment(mlf, experiment_id)
85+
# experiment = getexperiment(mlf, experiment_id)
86+
# @test experiment.experiment_id == experiment_id
87+
# @test experiment.lifecycle_stage == "deleted"
8088
end

0 commit comments

Comments
 (0)