
Commit 617a995

Merge pull request #966 from mlr-org/learner_cv_average ("Learner cv average")
2 parents 73a4fe6 + c612020


62 files changed: +2069 -89 lines

.Rbuildignore

Lines changed: 2 additions & 0 deletions
@@ -25,3 +25,5 @@
 ^\.vscode$
 ^\.lintr$
 ^\.pre-commit-config\.yaml$
+^AGENTS\.md$
+^CLAUDE\.md$

AGENTS.md

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
<persistence>
1. If the user asked you a question, try to gather information and answer the question to the best of your ability.
2. If the user asked you to review code, work and gather the required information to give a code review according to the `<guiding_principles>` and general best practices. Do not ask any more questions, just provide a best-effort code review.
3. Otherwise:
  - You are an agent - please keep going until the user's query is completely resolved before ending your turn and yielding back to the user.
  - If the instructions are unclear, think about what information you need and gather it from the user *right away*, so you can then work autonomously for many turns.
  - Be extra-autonomous. The user wants you to work on your own once you have started.
  - Only terminate your turn when you are sure that the problem is solved.
  - Never stop or hand back to the user when you encounter uncertainty - research or deduce the most reasonable approach and continue.
  - Do not ask the human to confirm or clarify assumptions except at the very beginning, as assumptions can always be adjusted later - decide what the most reasonable assumption is, proceed with it, and document it for the user's reference after you finish acting.
  - You are working inside a secure container; you cannot break anything vital, so do not ask for permission and be bold.
</persistence>
<work_loop>
- At the beginning:
  - When asked a question about the code or in general, or asked for a code review, gather the necessary information, answer right away, and finish.
  - When instructions are unclear, ask clarifying questions at the beginning.
- During work:
  - Think before you act. Plan ahead. Feel free to think more than you would otherwise; look at things from different angles, consider different scenarios.
  - If possible, write a few tests *before* implementing a feature or fixing a bug.
    - For a bug fix, write a test that captures the bug before fixing the bug.
    - For a feature, create tests to the degree it is possible. Try really hard. If it is not possible, at least create test stubs in the form of empty `test_that()` blocks to be filled in later.
    - Tests should be sensibly thorough. Write more thorough tests only when asked by the user to write tests.
  - Work and solve upcoming issues independently, using your best judgment.
  - Package progress into organic git commits. You may overwrite commits that are not on 'origin' yet, but do so only if it has great benefit. If you are on git branch `master`, create a new aptly named branch; never commit into `master`. Otherwise, do not leave the current git branch.
  - Again: create git commits at organic points. In the past, you tended to make too few git commits.
- If any issues pop up:
  - If you noticed anything that surprised you, or anything that would have helped you substantially with your work had you known it right away, add it to the `<agent_notes>` section of the `AGENTS.md` file. Future agents will then have access to this information. Use it to capture technical insights, failed approaches, user preferences, and other things future agents should know.
- After feature implementation, write tests:
  - If you were asked to implement a feature and have not yet done so, fill in the `test_that()` stubs created earlier or create new tests, to the degree that they make sense.
  - If you were asked to fix a bug, check again that there are regression tests.
- When you are done:
  - Write a short summary of what you did, what decisions you had to make that went beyond what the user asked of you, and other things the user should know about, as a chat response to the user.
  - Unless you were working on something minor, or you are leaving things as an obvious work-in-progress, do a git commit.
</work_loop>
<debugging>
When fixing problems, always make sure you know the actual reason for the problem first:

1. Form hypotheses about what the issue could be.
2. Find a way to test these hypotheses, and test them. If necessary, ask for assistance from the human, who may e.g. need to interact manually with the software.
3. If you accept a hypothesis, apply an appropriate fix. The fix may not work and the hypothesis may turn out to be false; in that case, undo the fix unless it actually improves code quality overall. Do not let unnecessary fixes for imaginary issues that never materialized clog up the code.
</debugging>
<guiding_principles>
Straightforwardness: Avoid ideological adherence to other programming principles when something can be solved in a simple, short, straightforward way. Otherwise:

- Simplicity: Favor small, focused components and avoid unnecessary complexity in design or logic.
  - This also means: avoid overly defensive code. Observe the typical level of defensiveness when looking at the code.
- Idiomaticity: Solve problems the way they "should" be solved in the respective language: the way a professional in that language would have approached it.
- Readability and maintainability are primary concerns, even at the cost of conciseness or performance.
- Doing it right is better than doing it fast. You are not in a rush. Never skip steps or take shortcuts.
- Tedious, systematic work is often the correct solution. Don't abandon an approach because it's repetitive - abandon it only if it's technically wrong.
- Honesty is a core value. Be honest about changes you have made and potential negative effects; these are okay. Be honest about shortcomings of other team members' plans and implementations; we all care more about the project than our egos. Be honest if you don't know something: say "I don't know" when appropriate.
</guiding_principles>
<project_info>

`mlr3pipelines` is a package that extends the `mlr3` ecosystem by adding preprocessing operations and a way to compose them into computational graphs.

- The package is very object-oriented; most things use R6.
- Coding style: we use `snake_case` for variables and `UpperCamelCase` for R6 classes. We use `=` for assignment and mostly follow the tidyverse style guide otherwise. We use block indent (two spaces), *not* visual indent; i.e., we don't align code with opening parentheses in function calls, we align by block depth.
- User-facing API (`@export`ed things, public R6 methods) always needs checkmate `assert_*()` argument checks. Otherwise don't be overly defensive; look at the other code in the project to see our desired level of paranoia.
- Always read at least `R/PipeOp.R` and `R/PipeOpTaskPreproc.R` to see the base classes you will need in almost every task.
- Read `R/Graph.R` and `R/GraphLearner.R` to understand the Graph architecture.
- Before you start coding, look at other relevant `.R` files that do something similar to what you are supposed to implement.
- We use `testthat`, and most test files are in `tests/testthat/`. Read the additional important helpers in `inst/testthat/helper_functions.R` to understand our `PipeOpTaskPreproc` auto-test framework.
- Always write tests and execute them with `devtools::test(filter = )`; the entirety of our tests takes a long time, so only run tests for what you just wrote.
- Tests involving the `$man` field, and tests involving parallelization, do not work well when the package is loaded with `devtools::load_all()`, because of conflicts with the installed version. Ignore these failures; CI will take care of this.
- The quality of our tests is lower than it ideally should be. We are in the process of improving this over time. Always leave the `tests/testthat/` folder in a better state than you found it in!
- If `roxygenize()` / `document()` produce warnings that are unrelated to the code you wrote, ignore them. Do not fix code or formatting that is unrelated to what you are working on, but *do* mention bugs or problems that you noticed in your final report.
- When you write examples, make sure they work.
- A very small number of packages listed in `Suggests:` that are used by some tests / examples are missing; ignore warnings in that regard. You will never be asked to work on things that require these packages.
- Packages that we rely on; they generally have good documentation that can be queried, or they can be looked up on GitHub:
  - `mlr3` provides `Task`, `Learner`, `Measure`, `Prediction`, and various `***Result` classes; basically the foundation on which we build. <https://github.com/mlr-org/mlr3>
  - `mlr3misc` provides a lot of helper functions that we prefer to use over base R when available. <https://github.com/mlr-org/mlr3misc>
  - `paradox` provides the hyperparameter / configuration space: `ps()`, `p_int()`, `p_lgl()`, `p_fct()`, `p_uty()`, etc. <https://github.com/mlr-org/paradox>
  - For the mlr3 ecosystem as a whole, also consider the "mlr3 Book" as a reference: <https://mlr3book.mlr-org.com/>
- Semantics of paradox ParamSet parameters to pay attention to:
  - There is a distinction between "default" values and values that a parameter is initialized to: a "default" is the behaviour that happens when the parameter is not given at all; e.g. PipeOpPCA's `center` defaults to `TRUE`, since the underlying function (`prcomp`) does centering when the `center` argument is not given at all. In contrast, a parameter is "initialized" to some value if it is set to that value upon construction of a PipeOp. In rare cases, this can differ from the default, e.g. if the underlying default behaviour is suboptimal for preprocessing (e.g. it stores training data unnecessarily by default).
  - A parameter can be marked as "required" by having the tag `"required"`. It is a special tag that causes an error if the value is not set. A "required" parameter *cannot* have a "default", since semantically this is a contradiction: "default" would describe what happens when the param is not set, but param-not-set is an error.
  - When we write a preprocessing method ourselves, we usually don't do "default" behaviour and instead mark most things as "required". "default" is mostly used when we wrap some other library's function which itself has function argument default values.
  - We initialize a parameter by giving the `p_xxx(init = )` argument. Some old code does `param_set$values = list(...)` or `param_set$values$param = ...` in the constructor. This is deprecated; we do not unnecessarily change it in old code, but new code should use `init = `. A parameter should be documented as "initialized to" something if and only if the value is set through one of these methods in the constructor.
  - Inside the train / predict functions of PipeOps, hyperparameter values should be obtained through `pv = self$param_set$get_values(tags = )`, where `tags` is often `"train"`, `"predict"`, or some custom tag that groups hyperparameters by meaning (e.g. everything that should be passed to a specific function). A nice pattern is to call a function `fname` with many options configured through `pv` while also explicitly passing some arguments, as `invoke(fname, arg1 = val1, arg2 = val2, .args = pv)`, using `invoke` from `mlr3misc`.
  - paradox does type-checking and range-checking automatically; `get_values()` automatically checks that `"required"` params are present and not `NULL`. Therefore, we only do additional parameter feasibility checks in the rarest of cases.
- Minor things to be aware of:
  - Errors that are thrown in PipeOps are automatically wrapped by Graph to also mention the PipeOp ID, so it is not necessary to include that in error messages.

</project_info>
<agent_notes>

# Notes by Agents to other Agents

- R unit tests in this repo assume the helper `expect_man_exists()` is available. If you need to call it in a new test and you are working without mlr3pipelines installed, define a local fallback at the top of that test file before `expect_learner()` is used.
- Revdep helper scripts live in `attic/revdeps/`. `download_revdeps.R` downloads reverse dependency source tarballs; `install_revdep_suggests.R` installs Suggests for those revdeps without pulling in the revdeps themselves.

</agent_notes>
<your_task>
Again, when implementing something, focus on:

1. Think things through and plan ahead.
2. Tests before implementation, if possible. In any case, write high-quality tests; try to be better than the tests you find in this project.
3. Once you have started, work independently; we can always undo things if necessary.
4. Create sensible intermediate commits.
5. Check your work and make sure tests pass. But do not run *all* tests; they take a long time.
6. Write a report to the user at the end, informing them about decisions that were made autonomously, unexpected issues, etc.
</your_task>
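The paradox semantics described in `<project_info>` (the "default" vs. "init" distinction, `get_values(tags = )`, and the `invoke()` pattern) can be sketched in a few lines. This is a hedged, standalone sketch assuming the `paradox` and `mlr3misc` packages are installed; the parameter names mirror PipeOpPCA's `prcomp` wrapper, but the ParamSet below is built directly rather than taken from the actual PipeOp.

```r
library(paradox)   # ps(), p_lgl(), p_int()
library(mlr3misc)  # invoke()

param_set = ps(
  # "default" only *documents* what the wrapped function (prcomp) does when
  # the argument is not given at all -- it does not set anything:
  center = p_lgl(default = TRUE, tags = "train"),
  # "init" actually sets the value upon construction of the ParamSet:
  rank. = p_int(lower = 1, init = 2L, tags = "train")
)

# Only initialized/set values are returned; 'center' stays absent,
# so prcomp falls back to its own default behaviour (centering).
pv = param_set$get_values(tags = "train")

# The invoke() pattern: explicit arguments plus hyperparameters from pv.
result = invoke(stats::prcomp, x = as.matrix(iris[, 1:4]), .args = pv)
```

Because `center` was never set, `pv` contains only `rank.`, and `prcomp` receives `rank. = 2` alongside the explicitly passed data matrix.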

CLAUDE.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
AGENTS.md

DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ Config/testthat/edition: 3
 Config/testthat/parallel: true
 NeedsCompilation: no
 Roxygen: list(markdown = TRUE, r6 = FALSE)
-RoxygenNote: 7.3.2
+RoxygenNote: 7.3.3
 VignetteBuilder: knitr, rmarkdown
 Collate:
 'CnfAtom.R'

NEWS.md

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,8 @@
 * Fix: Added internal workaround for `PipeOpNMF` attaching `Biobase`, `BiocGenerics`, and `generics` to the search path during training, prediction or when printing its `$state`.
 * feat: allow dates in datefeatures pipe op and use data.table for date feature generation.
 * Added support for internal validation tasks to `PipeOpFeatureUnion`.
+* feat: `PipeOpLearnerCV` can reuse the cross-validation models during prediction by averaging their outputs (`resampling.predict_method = "cv_ensemble"`).
+* feat: `PipeOpRegrAvg` gets new `se_aggr` and `se_aggr_rho` hyperparameters and now allows various forms of SE aggregation.
 
 # mlr3pipelines 0.9.0
 

@@ -304,4 +306,3 @@
 # mlr3pipelines 0.1.0
 
 * Initial upload to CRAN.
-

R/PipeOpClassifAvg.R

Lines changed: 37 additions & 5 deletions
@@ -11,8 +11,23 @@
 #' Always returns a `"prob"` prediction, regardless of the incoming [`Learner`][mlr3::Learner]'s
 #' `$predict_type`. The label of the class with the highest predicted probability is selected as the
 #' `"response"` prediction. If the [`Learner`][mlr3::Learner]'s `$predict_type` is set to `"prob"`,
-#' the prediction obtained is also a `"prob"` type prediction with the probability predicted to be a
-#' weighted average of incoming predictions.
+#' the probability aggregation is controlled by `prob_aggr` (see below). If `$predict_type = "response"`,
+#' predictions are internally converted to one-hot probability vectors (point mass on the predicted class) before aggregation.
+#'
+#' ### `"prob"` aggregation:
+#'
+#' * **`prob_aggr = "mean"`** -- *Linear opinion pool (arithmetic mean of probabilities; default)*.
+#'   **Interpretation.** Mixture semantics: choose a base model with probability `w[i]`, then draw from its class distribution.
+#'   Decision-theoretically, this is the minimizer of `sum(w[i] * KL(p[i] || p))` over probability vectors `p`, where `KL(x || y)` is the Kullback-Leibler divergence.
+#'   **Typical behavior.** Conservative, better calibrated, and robust to near-zero probabilities (never assigns zero unless all members do).
+#'   This is the standard choice for probability averaging in ensembles and stacking.
+#'
+#' * **`prob_aggr = "log"`** -- *Log opinion pool / product of experts (geometric mean in probability space)*:
+#'   average per-model log-probabilities (or, equivalently, logits) and apply softmax.
+#'   **Interpretation.** Product semantics: `p_ens ~ prod_i p_i^{w[i]}`; minimizes `sum(w[i] * KL(p || p[i]))`.
+#'   **Typical behavior.** Sharper / lower entropy (emphasizes consensus regions), but can be **overconfident** and is sensitive
+#'   to zeros; use `prob_aggr_eps` to clip small probabilities for numerical stability. Often beneficial with strong, similarly
+#'   calibrated members (e.g., neural networks), less so when calibration is the priority.
 #'
 #' All incoming [`Learner`][mlr3::Learner]'s `$predict_type` must agree.
 #'

@@ -45,7 +60,14 @@
 #' The `$state` is left empty (`list()`).
 #'
 #' @section Parameters:
-#' The parameters are the parameters inherited from the [`PipeOpEnsemble`].
+#' The parameters are the parameters inherited from the [`PipeOpEnsemble`], as well as:
+#' * `prob_aggr` :: `character(1)`\cr
+#'   Controls how incoming class probabilities are aggregated. One of `"mean"` (linear opinion pool; default) or
+#'   `"log"` (log opinion pool / product of experts). See the description above for definitions and interpretation.
+#'   Only has an effect if the incoming predictions have `"prob"` values.
+#' * `prob_aggr_eps` :: `numeric(1)`\cr
+#'   Small positive constant used only for `prob_aggr = "log"` to clamp probabilities before taking logs, improving numerical
+#'   stability and avoiding `-Inf`. Ignored for `prob_aggr = "mean"`. Default is `1e-12`.
 #'
 #' @section Internals:
 #' Inherits from [`PipeOpEnsemble`] by implementing the `private$weighted_avg_predictions()` method.

@@ -81,7 +103,11 @@ PipeOpClassifAvg = R6Class("PipeOpClassifAvg",
   inherit = PipeOpEnsemble,
   public = list(
     initialize = function(innum = 0, collect_multiplicity = FALSE, id = "classifavg", param_vals = list()) {
-      super$initialize(innum, collect_multiplicity, id, param_vals = param_vals, prediction_type = "PredictionClassif", packages = "stats")
+      param_set = ps(
+        prob_aggr = p_fct(levels = c("mean", "log"), init = "mean", tags = c("predict", "prob_aggr")),
+        prob_aggr_eps = p_dbl(lower = 0, upper = 1, default = 1e-12, tags = c("predict", "prob_aggr"), depends = quote(prob_aggr == "log"))
+      )
+      super$initialize(innum, collect_multiplicity, id, param_set = param_set, param_vals = param_vals, prediction_type = "PredictionClassif", packages = "stats")
     }
   ),
   private = list(

@@ -96,7 +122,13 @@ PipeOpClassifAvg = R6Class("PipeOpClassifAvg",
 
       prob = NULL
       if (every(inputs, function(x) !is.null(x$prob))) {
-        prob = weighted_matrix_sum(map(inputs, "prob"), weights)
+        pv = self$param_set$get_values(tags = "prob_aggr")
+        if (pv$prob_aggr == "mean") {
+          prob = weighted_matrix_sum(map(inputs, "prob"), weights)
+        } else {  # prob_aggr == "log"
+          epsilon = pv$prob_aggr_eps %??% 1e-12
+          prob = weighted_matrix_logpool(map(inputs, "prob"), weights, epsilon = epsilon)
+        }
       } else if (every(inputs, function(x) !is.null(x$response))) {
        prob = weighted_factor_mean(map(inputs, "response"), weights, lvls)
       } else {
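The behavioral difference between the two pooling rules documented above can be seen with a tiny base-R sketch (the probability vectors are made up for illustration): the linear pool averages probabilities directly, while the log pool averages log-probabilities and renormalizes, which sharpens the consensus class.

```r
# Two models' class probabilities over three classes, equal weights.
p1 = c(0.6, 0.3, 0.1)
p2 = c(0.7, 0.2, 0.1)
w  = c(0.5, 0.5)

# Linear opinion pool: weighted arithmetic mean of probabilities.
linear_pool = w[1] * p1 + w[2] * p2

# Log opinion pool: weighted mean of log-probabilities, then renormalize
# (equivalently, a weighted geometric mean followed by normalization).
log_pool = exp(w[1] * log(p1) + w[2] * log(p2))
log_pool = log_pool / sum(log_pool)
```

Both results are valid probability vectors, and the log pool puts strictly more mass on the consensus class than the linear pool does, illustrating its "sharper / lower entropy" character.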

R/PipeOpEnsemble.R

Lines changed: 19 additions & 0 deletions
@@ -178,6 +178,25 @@ weighted_matrix_sum = function(matrices, weights) {
   accmat
 }
 
+# Weighted log-opinion pool (geometric) aggregation of probability matrices.
+# Rows = samples, columns = classes. Each matrix must have the same shape.
+# @param matrices list of matrices: per-learner probabilities
+# @param weights numeric: weights, same length as `matrices` (assumed to sum to 1 upstream)
+# @param epsilon numeric(1): small positive constant to clamp probabilities before taking logs
+# @return matrix: row-normalized aggregated probabilities (same shape as the inputs)
+weighted_matrix_logpool = function(matrices, weights, epsilon = 1e-12) {
+  assert_list(matrices, types = "matrix", min.len = 1)
+  assert_numeric(weights, len = length(matrices), any.missing = FALSE, finite = TRUE)
+  assert_number(epsilon, lower = 0, upper = 1)
+  acc = weights[1] * log(pmax(matrices[[1]], epsilon))
+  for (idx in seq_along(matrices)[-1]) {
+    acc = acc + weights[idx] * log(pmax(matrices[[idx]], epsilon))
+  }
+  P = exp(acc)
+  sweep(P, 1L, rowSums(P), "/")
+}
+
 # For a set of n `factor` vectors each of length l with the same k levels and a
 # numeric weight vector of length n, returns a matrix of dimension l times k.
 # Each cell contains the weighted relative frequency of the respective factor
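To illustrate why a log-pool helper needs the `pmax(., epsilon)` clamp: `log(0)` is `-Inf`, so a single model assigning zero probability to a class would drive the pooled probability for that class to exactly zero, no matter what the other models say. Clamping keeps the computation finite and leaves a very small (rather than hard-zero) pooled mass. The function below is a standalone re-sketch of that logic without the package's checkmate asserts; its name and signature are illustrative, not package exports.

```r
# Standalone weighted log-opinion pool over probability matrices
# (rows = samples, columns = classes).
logpool_sketch = function(matrices, weights, epsilon = 1e-12) {
  acc = 0
  for (i in seq_along(matrices)) {
    # Clamp before log so a zero probability becomes log(epsilon), not -Inf.
    acc = acc + weights[i] * log(pmax(matrices[[i]], epsilon))
  }
  P = exp(acc)
  sweep(P, 1L, rowSums(P), "/")  # renormalize each row to sum to 1
}

m1 = matrix(c(0.0, 1.0), nrow = 1)  # this model assigns zero mass to class 1
m2 = matrix(c(0.5, 0.5), nrow = 1)
pooled = logpool_sketch(list(m1, m2), c(0.5, 0.5))
```

Here `pooled` is finite and row-normalized; class 1 receives a tiny but nonzero probability instead of being vetoed outright by the single zero.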
