Skip to content

Commit ed297dc

Browse files
oleksiyskononenkost-pasha
authored andcommitted
Convert FTRL code to a core datatable class (#1417)
FTRL code has been converted into a core datatable class: - added `fit` and `predict` methods for training and making predictions; - added `reset` method to reset the model; - added getters and setters for all the model parameters; - the `model` itself can now be returned and loaded as a separate frame; - some other refactoring and modifications. The algo itself was also slightly optimized, however, it still gets the same score on Kaggle. So we're ok from the accuracy point of view.
1 parent 67abbdd commit ed297dc

File tree

9 files changed

+758
-192
lines changed

9 files changed

+758
-192
lines changed

c/datatablemodule.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include "csv/writer.h"
1414
#include "expr/py_expr.h"
1515
#include "extras/aggregator.h"
16-
#include "extras/ftrl.h"
16+
#include "extras/py_ftrl.h"
1717
#include "frame/py_frame.h"
1818
#include "options.h"
1919
#include "py_column.h"
@@ -199,7 +199,6 @@ void DatatableModule::init_methods() {
199199
add(METHODv(expr_unaryop));
200200
add(METHOD0(is_debug_mode));
201201
add(METHOD0(has_omp_support));
202-
init_methods_ftrl();
203202
init_methods_aggregate();
204203
init_methods_str();
205204
init_methods_options();
@@ -234,6 +233,8 @@ PyInit__datatable()
234233

235234
try {
236235
py::Frame::Type::init(m);
236+
py::Ftrl::Type::init(m);
237+
237238
} catch (const std::exception& e) {
238239
exception_to_python(e);
239240
return nullptr;

c/datatablemodule.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class DatatableModule : public py::ExtModule<DatatableModule> {
2929

3030
void init_methods();
3131
void init_methods_aggregate();// extra/aggergate.cc
32-
void init_methods_ftrl(); // extra/ftrl.cc
3332
void init_methods_str(); // str/py_str.cc
3433
void init_methods_options(); // options.cc
3534
void init_methods_sets(); // set_funcs.cc

c/extras/aggregator.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ namespace py {
4545

4646
// dt changes in-place with a new column added to the end of it
4747
DataTable* dt_members = agg.aggregate(dt).release();
48-
py::Frame* frame_members = py::Frame::from_datatable(dt_members);
48+
py::oobj df_members = py::oobj::from_new_reference(py::Frame::from_datatable(dt_members));
4949

50-
return frame_members;
50+
return df_members;
5151
}
5252
);
5353
}

0 commit comments

Comments
 (0)