Skip to content

Commit c602273

Browse files
committed
add data front end for RidgeRegression
1 parent fee2ac1 commit c602273

File tree

4 files changed

+22
-36
lines changed

4 files changed

+22
-36
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
99
IsURL = "ceb4388c-583f-448d-bb30-00b11e8c5682"
1010
LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
11+
LearnDataFrontEnds = "5cca22a3-9356-470e-ba1b-8268d0135a4b"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
MLCore = "c2834f40-e789-41da-a90e-33b280584a8c"
1314
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

src/LearnTestAPI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ using LinearAlgebra
4646
using Random
4747
using Statistics
4848
using UnPack
49+
import LearnDataFrontEnds
4950

5051
include("tools.jl")
5152
include("logging.jl")

src/learners/regression.jl

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
using LearnAPI
77
using Tables
88
using LinearAlgebra
9+
import LearnDataFrontEnds as FrontEnds
910

1011
# # NAIVE RIDGE REGRESSION WITH NO INTERCEPTS
1112

12-
# We overload `obs` to expose internal representation of data. See later for a simpler
13-
# variation using the `obs` fallback.
13+
# We implement a canned data front end. See `BabyRidgeRegressor` below for a no-frills
14+
# version.
1415

16+
# no docstring here; that goes with the constructor
1517
struct Ridge
1618
lambda::Float64
1719
end
@@ -43,28 +45,25 @@ Base.getindex(data::RidgeFitObs, I) =
4345
RidgeFitObs(data.A[:,I], data.names, data.y[I])
4446
Base.length(data::RidgeFitObs) = length(data.y)
4547

46-
# observations for consumption by `fit`:
47-
function LearnAPI.obs(::Ridge, data)
48-
X, y = data
49-
table = Tables.columntable(X)
50-
names = Tables.columnnames(table) |> collect
51-
RidgeFitObs(Tables.matrix(table)', names, y)
52-
end
48+
# add a canned data front end; `obs` will return objects of type `FrontEnds.Obs`:
49+
const frontend = FrontEnds.Saffron(view=true)
50+
LearnAPI.obs(learner::Ridge, data) = FrontEnds.fitobs(learner, data, frontend)
51+
LearnAPI.obs(model::RidgeFitted, data) = obs(model, data, frontend)
5352

54-
# for involutivity:
55-
LearnAPI.obs(::Ridge, data::RidgeFitObs) = data
53+
# training data deconstructors:
54+
LearnAPI.features(learner::Ridge, data) = LearnAPI.features(learner, data, frontend)
55+
LearnAPI.target(learner::Ridge, data) = LearnAPI.target(learner, data, frontend)
5656

57-
# for observations:
58-
function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)
57+
function LearnAPI.fit(learner::Ridge, observations::FrontEnds.Obs; verbosity=1)
5958

6059
# unpack hyperparameters and data:
6160
lambda = learner.lambda
62-
A = observations.A
61+
A = observations.features
6362
names = observations.names
64-
y = observations.y
63+
y = observations.target
6564

6665
# apply core learner:
67-
coefficients = (A*A' + learner.lambda*I)\(A*y) # 1 x p matrix
66+
coefficients = (A*A' + learner.lambda*I)\(A*y) # p x 1 matrix
6867

6968
# determine crude feature importances:
7069
feature_importances =
@@ -78,28 +77,13 @@ function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)
7877
return RidgeFitted(learner, coefficients, feature_importances, names)
7978

8079
end
81-
82-
# for unprocessed `data = (X, y)`:
8380
LearnAPI.fit(learner::Ridge, data; kwargs...) =
8481
fit(learner, obs(learner, data); kwargs...)
8582

86-
# extracting stuff from training data:
87-
LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
88-
LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
89-
LearnAPI.target(learner::Ridge, data) =
90-
LearnAPI.target(learner, obs(learner, data))
91-
92-
# observations for consumption by `predict`:
93-
LearnAPI.obs(::RidgeFitted, X) = Tables.matrix(X)'
94-
LearnAPI.obs(::RidgeFitted, X::AbstractMatrix) = X
95-
96-
# matrix input:
97-
LearnAPI.predict(model::RidgeFitted, ::Point, observations::AbstractMatrix) =
98-
observations'*model.coefficients
99-
100-
# tabular input:
101-
LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
102-
predict(model, Point(), obs(model, Xnew))
83+
LearnAPI.predict(model::RidgeFitted, ::Point, observations::FrontEnds.Obs) =
84+
(observations.features)'*model.coefficients
85+
LearnAPI.predict(model::RidgeFitted, kind_of_proxy, data) =
86+
LearnAPI.predict(model, kind_of_proxy, obs(model, data))
10387

10488
# accessor function:
10589
LearnAPI.feature_importances(model::RidgeFitted) = model.feature_importances

test/learners/regression.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ learner = LearnTestAPI.Ridge(lambda=0.5)
2626
@test :(LearnAPI.obs) in LearnAPI.functions(learner)
2727

2828
@test LearnAPI.target(learner, data) == y
29-
@test LearnAPI.features(learner, data) == X
29+
@test LearnAPI.features(learner, data).features == Tables.matrix(X)'
3030

3131
# verbose fitting:
3232
@test_logs(

0 commit comments

Comments
 (0)