Skip to content

Commit 4fab94b

Browse files
authored
Merge pull request #122 from IBM/precompile_tools
Precompile tools support
2 parents 4c11a2c + 3cc16a6 commit 4fab94b

File tree

4 files changed

+44
-26
lines changed

4 files changed

+44
-26
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
name = "AutoMLPipeline"
22
uuid = "08437348-eef5-4817-bc1b-d4e9459680d6"
33
authors = ["Paulito Palmes <[email protected]>"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
AMLPipelineBase = "e3c3008a-8869-4d53-9f34-c96f99c8a2b6"
88
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
99
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
10+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1011
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1314

1415
[compat]
1516
AMLPipelineBase = "0.1"
1617
CondaPkg = "0.2"
17-
DataFrames = "0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 1.0, 2.0"
18+
DataFrames = "0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 1"
19+
PrecompileTools = "1"
1820
PythonCall = "0.9"
1921
julia = "1"
2022

src/AutoMLPipeline.jl

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module AutoMLPipeline
22

3+
using PrecompileTools: @setup_workload, @compile_workload
34
using AMLPipelineBase
45
using AMLPipelineBase.AbsTypes
56
export fit, fit!, transform, transform!,fit_transform, fit_transform!
@@ -35,39 +36,54 @@ import AMLPipelineBase.AbsTypes: fit!, transform!
3536

3637
# --------------------------------------------
3738

38-
include("skpreprocessor.jl")
39+
@setup_workload begin
40+
@compile_workload begin
41+
include("skpreprocessor.jl")
42+
end
43+
end
3944
using .SKPreprocessors
4045
export SKPreprocessor, skpreprocessors
4146

42-
include("sklearners.jl")
47+
@setup_workload begin
48+
@compile_workload begin
49+
include("sklearners.jl")
50+
end
51+
end
4352
using .SKLearners
4453
export SKLearner, sklearners
4554

46-
include("skcrossvalidator.jl")
55+
@setup_workload begin
56+
@compile_workload begin
57+
include("skcrossvalidator.jl")
58+
end
59+
end
4760
using .SKCrossValidators
4861
export crossvalidate
4962

5063
export skoperator
64+
@setup_workload begin
65+
@compile_workload begin
66+
function skoperator(name::String; args...)::Machine
67+
sklr = keys(SKLearners.learner_dict)
68+
skpr = keys(SKPreprocessors.preprocessor_dict)
69+
if name sklr
70+
obj = SKLearner(name; args...)
71+
elseif name skpr
72+
obj = SKPreprocessor(name; args...)
73+
else
74+
skoperator()
75+
throw(ArgumentError("$name does not exist"))
76+
end
77+
return obj
78+
end
5179

52-
function skoperator(name::String; args...)::Machine
53-
sklr = keys(SKLearners.learner_dict)
54-
skpr = keys(SKPreprocessors.preprocessor_dict)
55-
if name sklr
56-
obj = SKLearner(name; args...)
57-
elseif name skpr
58-
obj = SKPreprocessor(name; args...)
59-
else
60-
skoperator()
61-
throw(ArgumentError("$name does not exist"))
62-
end
63-
return obj
64-
end
65-
66-
function skoperator()
67-
sklr = keys(SKLearners.learner_dict)
68-
skpr = keys(SKPreprocessors.preprocessor_dict)
69-
println("Please choose among these pipeline elements:")
70-
println([sklr..., skpr...])
80+
function skoperator()
81+
sklr = keys(SKLearners.learner_dict)
82+
skpr = keys(SKPreprocessors.preprocessor_dict)
83+
println("Please choose among these pipeline elements:")
84+
println([sklr..., skpr...])
85+
end
86+
end
7187
end
7288

7389
end # module

src/skcrossvalidator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function __init__()
4646
metric_dict["mean_squared_error"] = SKM.mean_squared_error
4747
metric_dict["mean_squared_log_error"] = SKM.mean_squared_log_error
4848
metric_dict["mean_absolute_error"] = SKM.mean_absolute_error
49-
metric_dict["median_absolute_error"] = SKM.median_absolute_error
49+
#metric_dict["median_absolute_error"] = SKM.median_absolute_error
5050
metric_dict["r2_score"] = SKM.r2_score
5151
metric_dict["max_error"] = SKM.max_error
5252
metric_dict["mean_poisson_deviance"] = SKM.mean_poisson_deviance

test/test_skcrossvalidator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function crossval_reg(ppl,X,Y,folds,verbose)
2121
@test crossvalidate(ppl,X,Y,"mean_squared_error",folds,verbose).mean < 0.5
2222
@test crossvalidate(ppl,X,Y,"mean_squared_log_error",folds,verbose).mean < 0.5
2323
@test crossvalidate(ppl,X,Y,"mean_absolute_error",folds,verbose).mean < 0.5
24-
@test crossvalidate(ppl,X,Y,"median_absolute_error",folds,verbose).mean < 0.5
24+
#@test crossvalidate(ppl,X,Y,"median_absolute_error",folds,verbose).mean < 0.5
2525
@test crossvalidate(ppl,X,Y,"r2_score",folds,verbose).mean > 0.50
2626
@test crossvalidate(ppl,X,Y,"max_error",folds,verbose).mean < 0.7
2727
@test crossvalidate(ppl,X,Y,"mean_poisson_deviance",folds,verbose).mean < 0.7

0 commit comments

Comments
 (0)