11# ' Barebone Partial Dependence Function
2- # '
2+ # '
33# ' Workhorse of the package, thus optimized for speed.
4- # '
4+ # '
55# ' @noRd
66# ' @keywords internal
7- # '
7+ # '
88# ' @inheritParams partial_dep
99# ' @param grid A vector, data.frame or matrix of grid values consistent with `v` and `X`.
1010# ' @param compress_X If `X` has a single non-`v` column: should duplicates be removed
1111# ' and compensated via case weights? Default is `TRUE`.
12- # ' @param compress_grid Should duplicates in `grid` be removed and PDs mapped back to
12+ # ' @param compress_grid Should duplicates in `grid` be removed and PDs mapped back to
1313# ' the original grid index? Default is `TRUE`.
14- # ' @returns
15- # ' A matrix of partial dependence values (one column per prediction dimension,
14+ # ' @returns
15+ # ' A matrix of partial dependence values (one column per prediction dimension,
1616# ' one row per grid row, in the same order as `grid`).
1717pd_raw <- function (
1818 object ,
@@ -23,8 +23,7 @@ pd_raw <- function(
2323 w = NULL ,
2424 compress_X = TRUE ,
2525 compress_grid = TRUE ,
26- ...
27- ) {
26+ ... ) {
2827 # Try different compressions
2928 if (compress_X && length(v ) == ncol(X ) - 1L ) {
3029 # Removes duplicates in X[, not_v] and compensates via w
@@ -37,13 +36,19 @@ pd_raw <- function(
3736 cmp_grid <- .compress_grid(grid = grid )
3837 grid <- cmp_grid [[" grid" ]]
3938 }
40-
39+
4140 # Now, the real work
4241 pred <- ice_raw(
43- object , v = v , X = X , grid = grid , pred_fun = pred_fun , pred_only = TRUE , ...
42+ object ,
43+ v = v ,
44+ X = X ,
45+ grid = grid ,
46+ pred_fun = pred_fun ,
47+ pred_only = TRUE ,
48+ ...
4449 )
4550 pd <- wrowmean(pred , ngroups = NROW(grid ), w = w )
46-
51+
4752 # Map back to grid order
4853 if (compress_grid && ! is.null(reindex <- cmp_grid [[" reindex" ]])) {
4954 return (pd [reindex , , drop = FALSE ])
@@ -52,44 +57,49 @@ pd_raw <- function(
5257}
5358
5459# ' Barebone ICE Function
55- # '
60+ # '
5661# ' Part of the workhorse function `pd_raw()`, thus optimized for speed.
57- # '
62+ # '
5863# ' @noRd
5964# ' @keywords internal
60- # '
65+ # '
6166# ' @inheritParams pd_raw
6267# ' @param pred_only Logical flag determining the output mode. If `TRUE`, just
6368# ' predictions. Otherwise, a list with two elements: `pred` (predictions)
64- # ' and `grid_pred` (the corresponding grid values in the same mode as the input,
69+ # ' and `grid_pred` (the corresponding grid values in the same mode as the input,
6570# ' but replicated over `X`).
66- # ' @returns
71+ # ' @returns
6772# ' Either a vector/matrix of predictions or a list with predictions and grid.
6873ice_raw <- function (
69- object , v , X , grid , pred_fun = stats :: predict , pred_only = TRUE , ...
70- ) {
74+ object ,
75+ v ,
76+ X ,
77+ grid ,
78+ pred_fun = stats :: predict ,
79+ pred_only = TRUE ,
80+ ... ) {
7181 D1 <- length(v ) == 1L
7282 n <- nrow(X )
7383 n_grid <- NROW(grid )
74-
84+
7585 # Explode everything to n * n_grid rows
7686 X_pred <- rep_rows(X , rep.int(seq_len(n ), n_grid ))
7787 if (D1 ) {
7888 grid_pred <- rep(grid , each = n )
7989 } else {
8090 grid_pred <- rep_rows(grid , rep_each(n_grid , n ))
8191 }
82-
92+
8393 # Vary v
8494 if (D1 && is.data.frame(X_pred )) {
85- X_pred [[v ]] <- grid_pred # [, v] <- slower if df
95+ X_pred [[v ]] <- grid_pred # [, v] <- slower if df
8696 } else {
8797 X_pred [, v ] <- grid_pred
8898 }
89-
99+
90100 # Calculate matrix/vector of predictions
91101 pred <- prepare_pred(pred_fun(object , X_pred , ... ))
92-
102+
93103 if (pred_only ) {
94104 return (pred )
95105 }
@@ -99,31 +109,31 @@ ice_raw <- function(
99109# Helper functions used only within pd_raw()
100110
101111# ' Compresses X
102- # '
112+ # '
103113# ' @description
104- # ' Internal function to remove duplicated rows in `X` based on columns not in `v`.
105- # ' Compensation is done by summing corresponding case weights `w`.
114+ # ' Internal function to remove duplicated rows in `X` based on columns not in `v`.
115+ # ' Compensation is done by summing corresponding case weights `w`.
106116# ' Currently implemented only for the case when there is a single non-`v` column in `X`.
107- # ' Can later be generalized to multiple columns via [paste()].
108- # '
117+ # ' Can later be generalized to multiple columns via [paste()].
118+ # '
109119# ' Notes:
110120# ' - This function is important for interaction calculations.
111121# ' - The initial check for having a single non-`v` column is very cheap.
112- # '
122+ # '
113123# ' @noRd
114124# ' @keywords internal
115- # '
125+ # '
116126# ' @inheritParams pd_raw
117127# ' @returns A list with `X` and `w`, potentially compressed.
118128.compress_X <- function (X , v , w = NULL ) {
119129 not_v <- setdiff(colnames(X ), v )
120130 if (length(not_v ) != 1L ) {
121- return (list (X = X , w = w )) # No optimization implemented for this case
131+ return (list (X = X , w = w )) # No optimization implemented for this case
122132 }
123133 x_not_v <- if (is.data.frame(X )) X [[not_v ]] else X [, not_v ]
124134 X_dup <- duplicated(x_not_v )
125135 if (! any(X_dup )) {
126- return (list (X = X , w = w )) # No optimization done
136+ return (list (X = X , w = w )) # No optimization done
127137 }
128138
129139 # Compensate via w
@@ -135,22 +145,22 @@ ice_raw <- function(
135145 x_not_v <- match(x_not_v , x_not_v [! X_dup ])
136146 }
137147 list (
138- X = X [! X_dup , , drop = FALSE ],
148+ X = X [! X_dup , , drop = FALSE ],
139149 w = c(rowsum(w , group = x_not_v , reorder = FALSE ))
140150 )
141151}
142152
143153# ' Compresses Grid
144- # '
145- # ' Internal function used to remove duplicated grid rows. Re-indexing to original grid
154+ # '
155+ # ' Internal function used to remove duplicated grid rows. Re-indexing to original grid
146156# ' rows needs to be done later, but this function provides the re-index vector to do so.
147157# ' Further note that checking for uniqueness can be costly for higher-dimensional grids.
148- # '
158+ # '
149159# ' @noRd
150160# ' @keywords internal
151- # '
161+ # '
152162# ' @inheritParams pd_raw
153- # ' @returns
163+ # ' @returns
154164# ' A list with `grid` (possibly compressed) and the optional `reindex` vector
155165# ' used to map compressed grid values back to the original grid rows. The original
156166# ' grid equals the compressed grid at indices `reindex`.
@@ -161,14 +171,14 @@ ice_raw <- function(
161171 return (list (grid = grid , reindex = NULL ))
162172 }
163173 out <- list (grid = ugrid )
164- if (NCOL(grid ) > = 2L ) { # Non-vector case
165- grid <- do.call(paste , c(as.data.frame(grid ), sep = " _:_" ))
166- ugrid <- do.call(paste , c(as.data.frame(ugrid ), sep = " _:_" ))
174+ if (NCOL(grid ) > = 2L ) { # Non-vector case (see merge())
175+ # can we drop the as.data.frame()? I think yes
176+ grid <- do.call(paste , c(as.data.frame(grid ), sep = " \r " ))
177+ ugrid <- do.call(paste , c(as.data.frame(ugrid ), sep = " \r " ))
167178 if (anyDuplicated(ugrid )) {
168- stop(" String '_:_' found in grid values at unlucky position." )
179+ stop(" Carriage return found in grid values at unlucky position." )
169180 }
170181 }
171182 out [[" reindex" ]] <- match(grid , ugrid )
172183 out
173184}
174-
0 commit comments