Skip to content

Commit 6da5fbc

Browse files
use ExtraTrees as default base learner
1 parent 24918ff commit 6da5fbc

File tree

56 files changed

+40
-31
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+40
-31
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "maturin"
44

55
[project]
66
name = "genbooster"
7-
version = "0.4.0"
7+
version = "0.5.0"
88
description = "A fast boosting implementation using Rust and Python"
99
requires-python = ">=3.7"
1010
authors = [

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="genbooster",
5-
version="0.4.0",
5+
version="0.5.0",
66
packages=find_packages(where="src"),
77
package_dir={"": "src"},
88
install_requires=[

src/genbooster/genboosterclassifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin
66
from sklearn.preprocessing import StandardScaler, OneHotEncoder
77
from sklearn.linear_model import Ridge
8+
from sklearn.tree import ExtraTreeRegressor
89
from .rust_core import RustBooster as _RustBooster
910

1011

@@ -13,7 +14,7 @@ class BoosterClassifier(BaseEstimator, ClassifierMixin):
1314
1415
Parameters:
1516
16-
base_estimator: Base learner to use for the booster.
17+
base_estimator: Base learner to use for the booster. Default is ExtraTreeRegressor.
1718
1819
n_estimators: Number of boosting stages to perform.
1920
@@ -56,7 +57,7 @@ def __init__(self,
5657
tolerance: float = 1e-4,
5758
random_state: Optional[int] = 42):
5859
if base_estimator is None:
59-
self.base_estimator = Ridge()
60+
self.base_estimator = ExtraTreeRegressor()
6061
else:
6162
self.base_estimator = base_estimator
6263
self.n_estimators = n_estimators

src/genbooster/genboosterregressor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin
66
from sklearn.preprocessing import StandardScaler, OneHotEncoder
77
from sklearn.linear_model import Ridge
8+
from sklearn.tree import ExtraTreeRegressor
89
from .rust_core import RustBooster as _RustBooster
910

11+
1012
class BoosterRegressor(BaseEstimator, RegressorMixin):
1113
"""Generic Gradient Boosting Regressor (for any base learner).
1214
@@ -57,6 +59,10 @@ def __init__(
5759
random_state: Optional[int] = 42
5860
):
5961
self.base_estimator = base_estimator
62+
if base_estimator is None:
63+
self.base_estimator = ExtraTreeRegressor()
64+
else:
65+
self.base_estimator = base_estimator
6066
self.n_estimators = n_estimators
6167
self.learning_rate = learning_rate
6268
self.n_hidden_features = n_hidden_features
3.48 MB
Binary file not shown.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
cd815e9d9fa64be9
1+
75e3a61205c0979d
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"rustc":8308448953891780981,"features":"[]","declared_features":"[]","target":7775661509111195289,"profile":7481509334846997914,"path":17777289886553719987,"deps":[[1773036428841864304,"linfa",false,13796361333782155478],[2307970661281702218,"numpy",false,9177328553660719973],[5910892534286594076,"rand",false,1642172812197523942],[7893086541683815303,"linfa_elasticnet",false,10738206144616120958],[11491204902755878355,"pyo3",false,9717495171982902883],[15437040057372272760,"linfa_linear",false,18104169532794384664],[16880308220206177859,"ndarray",false,13822828271593039946],[17246046207308087067,"linfa_pls",false,17110029286974277900]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/genbooster-53dac35c39213789/dep-lib-genbooster"}}],"rustflags":[],"metadata":7797948686568424061,"config":2202906307356721367,"compile_kind":0}
1+
{"rustc":8308448953891780981,"features":"[]","declared_features":"[]","target":7775661509111195289,"profile":7631824717372466827,"path":17777289886553719987,"deps":[[1773036428841864304,"linfa",false,13796361333782155478],[2307970661281702218,"numpy",false,5246450974553583619],[5910892534286594076,"rand",false,1642172812197523942],[7893086541683815303,"linfa_elasticnet",false,10738206144616120958],[11491204902755878355,"pyo3",false,10049241876928269031],[15437040057372272760,"linfa_linear",false,18104169532794384664],[16880308220206177859,"ndarray",false,13822828271593039946],[17246046207308087067,"linfa_pls",false,17110029286974277900]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/genbooster-53dac35c39213789/dep-lib-genbooster"}}],"rustflags":[],"metadata":7797948686568424061,"config":2202906307356721367,"compile_kind":0}

target/debug/.fingerprint/genbooster-53dac35c39213789/output-lib-genbooster

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
63469d985f7adb86
1+
e70a4a703e14768b
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"rustc":8308448953891780981,"features":"[\"default\", \"extension-module\", \"indoc\", \"macros\", \"pyo3-macros\", \"unindent\"]","declared_features":"[\"abi3\", \"abi3-py310\", \"abi3-py311\", \"abi3-py37\", \"abi3-py38\", \"abi3-py39\", \"anyhow\", \"auto-initialize\", \"chrono\", \"default\", \"experimental-inspect\", \"extension-module\", \"eyre\", \"full\", \"generate-import-lib\", \"hashbrown\", \"indexmap\", \"indoc\", \"inventory\", \"macros\", \"multiple-pymethods\", \"nightly\", \"num-bigint\", \"num-complex\", \"pyo3-macros\", \"rust_decimal\", \"serde\", \"unindent\"]","target":3917981434725704152,"profile":6609184196851301694,"path":12318942634921028097,"deps":[[2452538001284770427,"cfg_if",false,7805845475459972740],[4198017634353570385,"parking_lot",false,15131708163124596309],[5748077777622396073,"memoffset",false,16724842799913194657],[7780729136333935213,"libc",false,8626046537738359144],[8413899861287479280,"unindent",false,13788096523579970376],[9535745648096932827,"pyo3_macros",false,16312271500822600982],[11491204902755878355,"build_script_build",false,14326852061224627412],[14558274942756651873,"pyo3_ffi",false,18417579364376809910],[18062083972812265452,"indoc",false,13561862281385562631]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/pyo3-3b8f58838dd2d5fe/dep-lib-pyo3"}}],"rustflags":[],"metadata":13602786906243840833,"config":2202906307356721367,"compile_kind":0}
1+
{"rustc":8308448953891780981,"features":"[\"default\", \"extension-module\", \"indoc\", \"macros\", \"pyo3-macros\", \"unindent\"]","declared_features":"[\"abi3\", \"abi3-py310\", \"abi3-py311\", \"abi3-py37\", \"abi3-py38\", \"abi3-py39\", \"anyhow\", \"auto-initialize\", \"chrono\", \"default\", \"experimental-inspect\", \"extension-module\", \"eyre\", \"full\", \"generate-import-lib\", \"hashbrown\", \"indexmap\", \"indoc\", \"inventory\", \"macros\", \"multiple-pymethods\", \"nightly\", \"num-bigint\", \"num-complex\", \"pyo3-macros\", \"rust_decimal\", \"serde\", \"unindent\"]","target":3917981434725704152,"profile":6609184196851301694,"path":12318942634921028097,"deps":[[2452538001284770427,"cfg_if",false,7805845475459972740],[4198017634353570385,"parking_lot",false,15131708163124596309],[5748077777622396073,"memoffset",false,16724842799913194657],[7780729136333935213,"libc",false,8626046537738359144],[8413899861287479280,"unindent",false,13788096523579970376],[9535745648096932827,"pyo3_macros",false,16312271500822600982],[11491204902755878355,"build_script_build",false,16766450166957907201],[14558274942756651873,"pyo3_ffi",false,510530259404512936],[18062083972812265452,"indoc",false,13561862281385562631]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/pyo3-3b8f58838dd2d5fe/dep-lib-pyo3"}}],"rustflags":[],"metadata":13602786906243840833,"config":2202906307356721367,"compile_kind":0}

0 commit comments

Comments
 (0)