Commit bde0968

tylerjthomas9, Tyler Thomas, and ericphanson authored
MLJ Integration (#16)
* Initial MLJ Interface
* add MLJModelInterface to [compat]
* MLJ support for CatBoostClassifier
* add MLJTestInterface
* reformat code, update documentation
* refactor to have MLJ Interface in separate module
* switch ci cache to julia-actions/cache
* add save/restore methods to MLJ tests
* expand test coverage
* bump to actions/checkout@v3
* python api naming -> wrapper
* MLJInterface -> MLJCatBoostInterface
* update with ablaom comments
* AbstractString -> Textual
* fix Textual
* updates from ablaom's feedback
* replace DataFrames.jl with Tables.jl
* Manually drop old OrderedFactor cols, use MMI.int
* fix table indexing on Julia v1.6
* fix formatting
* remove unnecessary line
* initial MMI.update and MLJ data front-end
* Dict -> NamedTuple, fix MMI.selectrows
* Refactor data processing to utilize CatBoost Pools
* Fix `prepare_input` return type (not tuple)
* add Default parameters
* format files
* change MMI.update to compare Julia structs
* update docstrings, feature_importances
* fix missing comma in `MMI.selectrows`
* fix `selectrows` indexing
* bump actions versions
* use julia cache for docs
* fix verbose logic
* propagate `first(y)` for `CatBoostClassifier`
* fix formatting
* Adjust `MMI.UnivariateFinite` pool
* expand catboost classifier `selectrows` support
* docstring adjustments
* add default `iteration_parameter`
* Update Project.toml

---------

Co-authored-by: Tyler Thomas <[email protected]>
Co-authored-by: Eric Hanson <[email protected]>
1 parent 9cab2d6 commit bde0968

30 files changed: +1015 -128 lines changed

.github/workflows/CI.yml

Lines changed: 9 additions & 7 deletions
@@ -26,26 +26,28 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1'
+          - '1.6'
+          - '1'
+          - 'nightly'
         os:
           - ubuntu-latest
         arch:
           - x64
+        include:
+          - os: windows-latest
+            version: '1'
+            arch: x64
     env:
       PYTHON: ''
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
         with:
           fetch-depth: 0
       - uses: julia-actions/setup-julia@v1
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      - uses: actions/cache@v2
-        with:
-          path: ~/.julia/artifacts
-          key: ${{ runner.os }}-test-artifacts-${{ hashFiles('**/Project.toml') }}
-          restore-keys: ${{ runner.os }}-test-artifacts
+      - uses: julia-actions/cache@v1 # https://github.com/julia-actions/cache
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
       - uses: julia-actions/julia-processcoverage@v1

.github/workflows/docs.yml

Lines changed: 2 additions & 6 deletions
@@ -18,15 +18,11 @@ jobs:
   Documentation:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: julia-actions/setup-julia@latest
         with:
           version: 1.6 # earliest supported version
-      - uses: actions/cache@v2
-        with:
-          path: ~/.julia/artifacts
-          key: ${{ runner.os }}-docs-artifacts-${{ hashFiles('**/Project.toml') }}
-          restore-keys: ${{ runner.os }}-docs-artifacts
+      - uses: julia-actions/cache@v1 # https://github.com/julia-actions/cache
       - uses: julia-actions/julia-docdeploy@releases/v1
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token

.github/workflows/format_check.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
       - uses: julia-actions/setup-julia@latest
         with:
           version: 1.6.0
-      - uses: actions/checkout@v1
+      - uses: actions/checkout@v3
       - name: Instantiate `format` environment and format
         run: |
           julia --project=format -e 'using Pkg; Pkg.instantiate()'

Project.toml

Lines changed: 7 additions & 4 deletions
@@ -1,25 +1,28 @@
 name = "CatBoost"
 uuid = "e2e10f9a-a85d-4fa9-b6b2-639a32100a12"
 authors = ["Beacon Biosignals, Inc."]
-version = "0.2.0"
+version = "0.3.0"
 
 [deps]
-DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
 Aqua = "0.5"
-DataFrames = "0.22, 1"
+MLJModelInterface = "1"
 OrderedCollections = "1.4"
 PythonCall = "0.9"
 Tables = "1.4"
 julia = "1.6"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "Test"]
+test = ["Aqua", "DataFrames", "MLJBase", "MLJTestInterface", "Test"]

README.md

Lines changed: 26 additions & 0 deletions
@@ -16,6 +16,7 @@ Julia interface to [CatBoost](https://catboost.ai/).
 module Regression
 
 using CatBoost
+using PythonCall
 
 train_data = PyList([[1, 4, 5, 6], [4, 5, 6, 7], [30, 40, 50, 60]])
 eval_data = PyList([[2, 4, 6, 8], [1, 4, 50, 60]])
@@ -32,3 +33,28 @@ preds = predict(model, eval_data)
 
 end # module
 ```
+
+## MLJ Example
+```julia
+module Regression
+
+using CatBoost.MLJCatBoostInterface
+using DataFrames
+using MLJBase
+
+train_data = DataFrame([[1,4,30], [4,5,40], [5,6,50], [6,7,60]], :auto)
+eval_data = DataFrame([[2,1], [4,4], [6,50], [8,60]], :auto)
+train_labels = [10.0, 20.0, 30.0]
+
+# Initialize MLJ Machine
+model = CatBoostRegressor(iterations = 2, learning_rate = 1, depth = 2)
+mach = machine(model, train_data, train_labels)
+
+# Fit model
+MLJBase.fit!(mach)
+
+# Get predictions
+preds = predict(mach, eval_data)
+
+end # module
+```

docs/make.jl

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,8 @@ using CatBoost
 using Documenter
 
 makedocs(; modules=[CatBoost], sitename="CatBoost.jl", authors="Beacon Biosignals, Inc.",
-         pages=["API Documentation" => "index.md"])
+         pages=["Introduction" => "index.md", "Wrapper" => "wrapper.md",
+                "MLJ API" => "mlj_api.md"])
 
 deploydocs(; repo="github.com/beacon-biosignals/CatBoost.jl.git", push_preview=true,
            devbranch="main")

docs/src/index.md

Lines changed: 10 additions & 6 deletions
@@ -1,13 +1,17 @@
-# API Documentation
+# CatBoost.jl
 
-Below is the API documentation for CatBoost.jl.
+Julia interface to [CatBoost](https://catboost.ai/). This library is a wrapper around CatBoost's Python package via [PythonCall.jl](https://github.com/cjdoris/PythonCall.jl).
 
 For a nice introduction to the package, see the [examples](https://github.com/beacon-biosignals/CatBoost.jl/blob/main/examples/).
 
-```@meta
-CurrentModule = CatBoost
+# Installation
+
+This package is available in the Julia General Registry. You can install it with either of the following commands:
+
+```
+pkg> add CatBoost
 ```
 
-```@autodocs
-Modules = [CatBoost]
+```julia
+julia> using Pkg; Pkg.add("CatBoost")
 ```

docs/src/mlj_api.md

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# MLJ API
+
+Below is the MLJ API documentation for CatBoost.jl.
+
+```@docs
+CatBoost.MLJCatBoostInterface.CatBoostClassifier
+CatBoost.MLJCatBoostInterface.CatBoostRegressor
+```

docs/src/wrapper.md

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Python Wrapper
+
+Below is the Python wrapper documentation for CatBoost.jl.
+
+```@docs
+Pool
+CatBoost.CatBoostClassifier
+CatBoost.CatBoostRegressor
+cv
+to_catboost
+to_pandas
+pandas_to_tbl
+feature_importance
+load_dataset
+```

examples/mlj/binary.jl

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+module Binary
+
+using CatBoost.MLJCatBoostInterface
+using DataFrames
+using MLJBase
+using PythonCall
+
+# Initialize data
+train_data = DataFrame([coerce(["a", "a", "c"], Multiclass),
+                        coerce(["b", "b", "d"], Multiclass),
+                        coerce([0, 0, 1], OrderedFactor), [4, 5, 40], [5, 6, 50],
+                        [6, 7, 60]], :auto)
+train_labels = coerce([1, 1, -1], OrderedFactor)
+eval_data = DataFrame([coerce(["a", "a"], Multiclass), coerce(["b", "d"], Multiclass),
+                       coerce([0, 0], OrderedFactor), [4, 4], [6, 50], [8, 60]], :auto)
+
+# Initialize CatBoostClassifier
+model = CatBoostClassifier(; iterations=2, learning_rate=1.0, depth=2)
+mach = machine(model, train_data, train_labels)
+
+# Fit model
+MLJBase.fit!(mach)
+
+# Get predicted classes
+preds_class = MLJBase.predict_mode(mach, eval_data)
+
+# Get predicted probabilities for each class
+preds_proba = MLJBase.predict(mach, eval_data)
+
+end # module
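
Reader's note on `examples/mlj/binary.jl`: the example stops at `preds_proba`, which, judging by the commit message's mention of `MMI.UnivariateFinite`, is presumably a vector of per-row class distributions rather than plain numbers. The following is a minimal sketch of how one might pull numeric probabilities out of it. It assumes CatBoost.jl (with its Python dependency), DataFrames.jl, and MLJBase.jl are installed, and that `pdf` and `mode` are available through MLJBase. The variable names `p_pos` and `likely_class` are hypothetical and not part of the commit.

```julia
# Sketch only: reuses the toy data from examples/mlj/binary.jl.
using CatBoost.MLJCatBoostInterface
using DataFrames
using MLJBase

train_data = DataFrame([coerce(["a", "a", "c"], Multiclass),
                        coerce(["b", "b", "d"], Multiclass),
                        coerce([0, 0, 1], OrderedFactor), [4, 5, 40], [5, 6, 50],
                        [6, 7, 60]], :auto)
train_labels = coerce([1, 1, -1], OrderedFactor)
eval_data = DataFrame([coerce(["a", "a"], Multiclass), coerce(["b", "d"], Multiclass),
                       coerce([0, 0], OrderedFactor), [4, 4], [6, 50], [8, 60]], :auto)

model = CatBoostClassifier(; iterations=2, learning_rate=1.0, depth=2)
mach = machine(model, train_data, train_labels)
MLJBase.fit!(mach; verbosity=0)

# Probabilistic predictions: one distribution per evaluation row (assumed)
preds_proba = MLJBase.predict(mach, eval_data)

# Probability assigned to the class labelled `1` for each row
p_pos = MLJBase.pdf.(preds_proba, 1)

# Most likely class per row; intended to match `predict_mode(mach, eval_data)`
likely_class = MLJBase.mode.(preds_proba)
```

Broadcasting `pdf` over the predictions keeps the class label explicit, which avoids relying on the internal ordering of the two levels (`-1` and `1`).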
