Skip to content

Commit 5cdee87

Browse files
Fix MLJ example in docs (#33)
1 parent 28930d4 commit 5cdee87

File tree

5 files changed

+14
-15
lines changed

5 files changed

+14
-15
lines changed

README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,20 @@ using CatBoost.MLJCatBoostInterface
5656
using DataFrames
5757
using MLJBase
5858

59-
train_data = DataFrame([[1,4,30], [4,5,40], [5,6,50], [6,7,60]], :auto)
60-
eval_data = DataFrame([[2,1], [4,4], [6,50], [8,60]], :auto)
61-
train_labels = [10.0, 20.0, 30.0]
59+
# Initialize data
60+
train_data = DataFrame([[1, 4, 30], [4, 5, 40], [5, 6, 50], [6, 7, 60]], :auto)
61+
train_labels = [10.0, 20.0, 30.0]
62+
eval_data = DataFrame([[2, 1], [4, 4], [6, 50], [8, 60]], :auto)
6263

63-
# Initialize MLJ Machine
64-
model = CatBoostRegressor(iterations = 2, learning_rate = 1, depth = 2)
64+
# Initialize CatBoostClassifier
65+
model = CatBoostRegressor(; iterations=2, learning_rate=1.0, depth=2)
6566
mach = machine(model, train_data, train_labels)
6667

6768
# Fit model
6869
MLJBase.fit!(mach)
6970

7071
# Get predictions
71-
preds = predict(model, eval_data)
72+
preds_class = MLJBase.predict(mach, eval_data)
7273

7374
end # module
7475
```

docs/src/index.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,20 @@ using CatBoost.MLJCatBoostInterface
4848
using DataFrames
4949
using MLJBase
5050

51-
train_data = DataFrame([[1,4,30], [4,5,40], [5,6,50], [6,7,60]], :auto)
52-
eval_data = DataFrame([[2,1], [4,4], [6,50], [8,60]], :auto)
53-
train_labels = [10.0, 20.0, 30.0]
51+
# Initialize data
52+
train_data = DataFrame([[1, 4, 30], [4, 5, 40], [5, 6, 50], [6, 7, 60]], :auto)
53+
train_labels = [10.0, 20.0, 30.0]
54+
eval_data = DataFrame([[2, 1], [4, 4], [6, 50], [8, 60]], :auto)
5455

55-
# Initialize MLJ Machine
56-
model = CatBoostRegressor(iterations = 2, learning_rate = 1, depth = 2)
56+
# Initialize CatBoostClassifier
57+
model = CatBoostRegressor(; iterations=2, learning_rate=1.0, depth=2)
5758
mach = machine(model, train_data, train_labels)
5859

5960
# Fit model
6061
MLJBase.fit!(mach)
6162

6263
# Get predictions
63-
preds = predict(model, eval_data)
64+
preds_class = MLJBase.predict(mach, eval_data)
6465

6566
end # module
6667
```

examples/mlj/binary.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module Binary
33
using CatBoost.MLJCatBoostInterface
44
using DataFrames
55
using MLJBase
6-
using PythonCall
76

87
# Initialize data
98
train_data = DataFrame([coerce(["a", "a", "c"], Multiclass),

examples/mlj/multiclass.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module Multiclass
33
using CatBoost.MLJCatBoostInterface
44
using DataFrames
55
using MLJBase
6-
using PythonCall
76

87
# Initialize data
98
train_data = DataFrame([coerce(["a", "a", "c"], MLJBase.Multiclass),

examples/mlj/regression.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module Regression
33
using CatBoost.MLJCatBoostInterface
44
using DataFrames
55
using MLJBase
6-
using PythonCall
76

87
# Initialize data
98
train_data = DataFrame([[1, 4, 30], [4, 5, 40], [5, 6, 50], [6, 7, 60]], :auto)

0 commit comments

Comments
 (0)