Skip to content

Commit a0bf483

Browse files
authored
Merge pull request #135 from ModelOriented/fix-multi-column-key-separator
Replace multi-column separator
2 parents 3afaba0 + 78fd090 commit a0bf483

File tree

3 files changed

+58
-44
lines changed

3 files changed

+58
-44
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
- Changed Tidymodels example to probabilistic multiclass, see discussion in [#129](https://github.com/ModelOriented/hstats/issues/129).
66

7+
## Internals
8+
9+
- Use "\r" instead of "_:_" as separator to paste values in multi-column grids (similar to `merge()`). [#133](https://github.com/ModelOriented/hstats/issues/133).
10+
711
# hstats 1.2.1
812

913
## Usability

R/pd_raw.R

Lines changed: 53 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
#' Barebone Partial Dependence Function
2-
#'
2+
#'
33
#' Workhorse of the package, thus optimized for speed.
4-
#'
4+
#'
55
#' @noRd
66
#' @keywords internal
7-
#'
7+
#'
88
#' @inheritParams partial_dep
99
#' @param grid A vector, data.frame or matrix of grid values consistent with `v` and `X`.
1010
#' @param compress_X If `X` has a single non-`v` column: should duplicates be removed
1111
#' and compensated via case weights? Default is `TRUE`.
12-
#' @param compress_grid Should duplicates in `grid` be removed and PDs mapped back to
12+
#' @param compress_grid Should duplicates in `grid` be removed and PDs mapped back to
1313
#' the original grid index? Default is `TRUE`.
14-
#' @returns
15-
#' A matrix of partial dependence values (one column per prediction dimension,
14+
#' @returns
15+
#' A matrix of partial dependence values (one column per prediction dimension,
1616
#' one row per grid row, in the same order as `grid`).
1717
pd_raw <- function(
1818
object,
@@ -23,8 +23,7 @@ pd_raw <- function(
2323
w = NULL,
2424
compress_X = TRUE,
2525
compress_grid = TRUE,
26-
...
27-
) {
26+
...) {
2827
# Try different compressions
2928
if (compress_X && length(v) == ncol(X) - 1L) {
3029
# Removes duplicates in X[, not_v] and compensates via w
@@ -37,13 +36,19 @@ pd_raw <- function(
3736
cmp_grid <- .compress_grid(grid = grid)
3837
grid <- cmp_grid[["grid"]]
3938
}
40-
39+
4140
# Now, the real work
4241
pred <- ice_raw(
43-
object, v = v, X = X, grid = grid, pred_fun = pred_fun, pred_only = TRUE, ...
42+
object,
43+
v = v,
44+
X = X,
45+
grid = grid,
46+
pred_fun = pred_fun,
47+
pred_only = TRUE,
48+
...
4449
)
4550
pd <- wrowmean(pred, ngroups = NROW(grid), w = w)
46-
51+
4752
# Map back to grid order
4853
if (compress_grid && !is.null(reindex <- cmp_grid[["reindex"]])) {
4954
return(pd[reindex, , drop = FALSE])
@@ -52,44 +57,49 @@ pd_raw <- function(
5257
}
5358

5459
#' Barebone ICE Function
55-
#'
60+
#'
5661
#' Part of the workhorse function `pd_raw()`, thus optimized for speed.
57-
#'
62+
#'
5863
#' @noRd
5964
#' @keywords internal
60-
#'
65+
#'
6166
#' @inheritParams pd_raw
6267
#' @param pred_only Logical flag determining the output mode. If `TRUE`, just
6368
#' predictions. Otherwise, a list with two elements: `pred` (predictions)
64-
#' and `grid_pred` (the corresponding grid values in the same mode as the input,
69+
#' and `grid_pred` (the corresponding grid values in the same mode as the input,
6570
#' but replicated over `X`).
66-
#' @returns
71+
#' @returns
6772
#' Either a vector/matrix of predictions or a list with predictions and grid.
6873
ice_raw <- function(
69-
object, v, X, grid, pred_fun = stats::predict, pred_only = TRUE, ...
70-
) {
74+
object,
75+
v,
76+
X,
77+
grid,
78+
pred_fun = stats::predict,
79+
pred_only = TRUE,
80+
...) {
7181
D1 <- length(v) == 1L
7282
n <- nrow(X)
7383
n_grid <- NROW(grid)
74-
84+
7585
# Explode everything to n * n_grid rows
7686
X_pred <- rep_rows(X, rep.int(seq_len(n), n_grid))
7787
if (D1) {
7888
grid_pred <- rep(grid, each = n)
7989
} else {
8090
grid_pred <- rep_rows(grid, rep_each(n_grid, n))
8191
}
82-
92+
8393
# Vary v
8494
if (D1 && is.data.frame(X_pred)) {
85-
X_pred[[v]] <- grid_pred # [, v] <- slower if df
95+
X_pred[[v]] <- grid_pred # [, v] <- slower if df
8696
} else {
8797
X_pred[, v] <- grid_pred
8898
}
89-
99+
90100
# Calculate matrix/vector of predictions
91101
pred <- prepare_pred(pred_fun(object, X_pred, ...))
92-
102+
93103
if (pred_only) {
94104
return(pred)
95105
}
@@ -99,31 +109,31 @@ ice_raw <- function(
99109
# Helper functions used only within pd_raw()
100110

101111
#' Compresses X
102-
#'
112+
#'
103113
#' @description
104-
#' Internal function to remove duplicated rows in `X` based on columns not in `v`.
105-
#' Compensation is done by summing corresponding case weights `w`.
114+
#' Internal function to remove duplicated rows in `X` based on columns not in `v`.
115+
#' Compensation is done by summing corresponding case weights `w`.
106116
#' Currently implemented only for the case when there is a single non-`v` column in `X`.
107-
#' Can later be generalized to multiple columns via [paste()].
108-
#'
117+
#' Can later be generalized to multiple columns via [paste()].
118+
#'
109119
#' Notes:
110120
#' - This function is important for interaction calculations.
111121
#' - The initial check for having a single non-`v` column is very cheap.
112-
#'
122+
#'
113123
#' @noRd
114124
#' @keywords internal
115-
#'
125+
#'
116126
#' @inheritParams pd_raw
117127
#' @returns A list with `X` and `w`, potentially compressed.
118128
.compress_X <- function(X, v, w = NULL) {
119129
not_v <- setdiff(colnames(X), v)
120130
if (length(not_v) != 1L) {
121-
return(list(X = X, w = w)) # No optimization implemented for this case
131+
return(list(X = X, w = w)) # No optimization implemented for this case
122132
}
123133
x_not_v <- if (is.data.frame(X)) X[[not_v]] else X[, not_v]
124134
X_dup <- duplicated(x_not_v)
125135
if (!any(X_dup)) {
126-
return(list(X = X, w = w)) # No optimization done
136+
return(list(X = X, w = w)) # No optimization done
127137
}
128138

129139
# Compensate via w
@@ -135,22 +145,22 @@ ice_raw <- function(
135145
x_not_v <- match(x_not_v, x_not_v[!X_dup])
136146
}
137147
list(
138-
X = X[!X_dup, , drop = FALSE],
148+
X = X[!X_dup, , drop = FALSE],
139149
w = c(rowsum(w, group = x_not_v, reorder = FALSE))
140150
)
141151
}
142152

143153
#' Compresses Grid
144-
#'
145-
#' Internal function used to remove duplicated grid rows. Re-indexing to original grid
154+
#'
155+
#' Internal function used to remove duplicated grid rows. Re-indexing to original grid
146156
#' rows needs to be done later, but this function provides the re-index vector to do so.
147157
#' Further note that checking for uniqueness can be costly for higher-dimensional grids.
148-
#'
158+
#'
149159
#' @noRd
150160
#' @keywords internal
151-
#'
161+
#'
152162
#' @inheritParams pd_raw
153-
#' @returns
163+
#' @returns
154164
#' A list with `grid` (possibly compressed) and the optional `reindex` vector
155165
#' used to map compressed grid values back to the original grid rows. The original
156166
#' grid equals the compressed grid at indices `reindex`.
@@ -161,14 +171,14 @@ ice_raw <- function(
161171
return(list(grid = grid, reindex = NULL))
162172
}
163173
out <- list(grid = ugrid)
164-
if (NCOL(grid) >= 2L) { # Non-vector case
165-
grid <- do.call(paste, c(as.data.frame(grid), sep = "_:_"))
166-
ugrid <- do.call(paste, c(as.data.frame(ugrid), sep = "_:_"))
174+
if (NCOL(grid) >= 2L) { # Non-vector case (see merge())
175+
# Keep as.data.frame(): a matrix grid must become a list of columns for do.call(paste, ...)
176+
grid <- do.call(paste, c(as.data.frame(grid), sep = "\r"))
177+
ugrid <- do.call(paste, c(as.data.frame(ugrid), sep = "\r"))
167178
if (anyDuplicated(ugrid)) {
168-
stop("String '_:_' found in grid values at unlucky position.")
179+
stop("Carriage return found in grid values at unlucky position.")
169180
}
170181
}
171182
out[["reindex"]] <- match(grid, ugrid)
172183
out
173184
}
174-

tests/testthat/test_partial_dep.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,6 @@ test_that(".compress_grid() leaves grid unchanged if unique", {
594594
})
595595

596596
test_that(".compress_grid() can fail with very strange values", {
597-
g <- data.frame(X = c("", "", "_:_"), Y = c("_:_", "_:_", ""))
597+
g <- data.frame(X = c("", "", "\r"), Y = c("\r", "\r", ""))
598598
expect_error(.compress_grid(g))
599599
})

0 commit comments

Comments
 (0)