Skip to content
Open
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ Roxygen: list(markdown = TRUE, r6 = FALSE)
RoxygenNote: 7.3.2
VignetteBuilder: knitr
Collate:
'DataBackendJoin.R'
'DataBackendMultiCbind.R'
'Graph.R'
'GraphLearner.R'
'mlr_pipeops.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
export("%>>!%")
export("%>>%")
export(DataBackendJoin)
export(DataBackendMultiCbind)
export(Graph)
export(GraphLearner)
export(LearnerClassifAvg)
Expand Down
160 changes: 160 additions & 0 deletions R/DataBackendJoin.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@


#' @title DataBackend Joining Two DataBackends
#'
#' @name DataBackendJoin
#' @format `R6Class` object inheriting from `DataBackend`.
#'
#' @description
#' Presents the join of two `DataBackend`s `b1` and `b2` as a single `DataBackend`.
#'
#' On construction, the row ids of both backends are matched with `merge()`: either on
#' the values of the columns `by_b1` / `by_b2` or, where these are `NULL`, on the row ids
#' themselves. Only the resulting table of matched row ids is stored; the actual data is
#' fetched from `b1` and `b2` when `$data()`, `$distinct()` or `$missings()` is called.
#'
#' The rows of the join are numbered `1, 2, ...`. This number is exposed as a newly
#' created primary key column named `"..row_id"` (a numeric suffix is appended if that
#' name is already taken). Columns that occur in both backends are read from `b2`.
#'
#' @section Construction:
#' ```
#' DataBackendJoin$new(b1, b2, type, by_b1 = NULL, by_b2 = NULL,
#'   b1_index_colname = NULL, b2_index_colname = NULL)
#' ```
#'
#' * `b1`, `b2`: `DataBackend`s to join. Both must support the `"data.table"` data format.
#' * `type`: one of `"left"`, `"right"`, `"outer"`, `"inner"`. `"left"` and `"outer"` keep
#'   unmatched rows of `b1`, `"right"` and `"outer"` keep unmatched rows of `b2`.
#' * `by_b1`, `by_b2`: name of the column of `b1` / `b2` to join on, or `NULL` to join on
#'   the row ids of the respective backend.
#' * `b1_index_colname`, `b2_index_colname`: optional name of an additional column that
#'   contains, for every row of the join, the row id of `b1` / `b2` it originates from
#'   (`NA` for unmatched rows). The name must not clash with an existing column other
#'   than the primary key of the respective backend.
#'
#' @export
DataBackendJoin = R6Class("DataBackendJoin", inherit = DataBackend, cloneable = FALSE,
  public = list(
    # Validate the inputs, compute the table of matched row ids and choose a primary key name.
    initialize = function(b1, b2, type, by_b1 = NULL, by_b2 = NULL, b1_index_colname = NULL, b2_index_colname = NULL) {
      assert_backend(b1)
      assert_backend(b2)

      if ("data.table" %nin% intersect(b1$data_formats, b2$data_formats)) {
        stop("DataBackendJoin currently only supports DataBackends that support 'data.table' format.")
      }

      assert_choice(type, c("left", "right", "outer", "inner"))

      colnames_b1 = b1$colnames
      colnames_b2 = b2$colnames
      allcolnames = union(colnames_b1, colnames_b2)

      assert_choice(by_b1, colnames_b1, null.ok = TRUE)
      assert_choice(by_b2, colnames_b2, null.ok = TRUE)

      assert_string(b1_index_colname, null.ok = TRUE)
      assert_string(b2_index_colname, null.ok = TRUE)

      # An index column may reuse the name of its own backend's primary key, but no other column name.
      if (!is.null(b1_index_colname) && b1_index_colname %in% setdiff(allcolnames, b1$primary_key)) {
        stopf("b1_index_colname '%s' already a non-primary-key column in b1 or b2.", b1_index_colname)
      }
      if (!is.null(b2_index_colname) && b2_index_colname %in% setdiff(allcolnames, b2$primary_key)) {
        stopf("b2_index_colname '%s' already a non-primary-key column in b1 or b2.", b2_index_colname)
      }
      if (!is.null(b1_index_colname) && !is.null(b2_index_colname) && b1_index_colname == b2_index_colname) {
        # stopf(), not stop(): the message contains a '%s' placeholder that must be formatted
        stopf("b1_index_colname and b2_index_colname must be different, but are both '%s'.", b1_index_colname)
      }

      # every column name that is already in use, including the requested index columns
      taken_names = unique(c(allcolnames, b1_index_colname, b2_index_colname))

      rownames_b1 = b1$rownames
      rownames_b2 = b2$rownames

      # values to join on: the 'by' column if given, the row ids otherwise
      joinby_b1 = if (is.null(by_b1)) rownames_b1 else b1$data(rownames_b1, by_b1, data_format = "data.table")[[1]]
      joinby_b2 = if (is.null(by_b2)) rownames_b2 else b2$data(rownames_b2, by_b2, data_format = "data.table")[[1]]

      # index_table has one row per row of the join, with columns 'rownames_b1' and 'rownames_b2';
      # rows kept by all.x / all.y without a match have NA in the other backend's column.
      index_table = merge(data.table(rownames_b1, joinby_b1), data.table(rownames_b2, joinby_b2), by.x = "joinby_b1", by.y = "joinby_b2",
        all.x = type %in% c("left", "outer"), all.y = type %in% c("right", "outer"), sort = FALSE, allow.cartesian = TRUE)

      # the join key itself is not needed any more
      set(index_table, , "joinby_b1", NULL)

      # find a primary key name that collides neither with a data column nor with an index column
      pk = "..row_id"
      index = 0
      while (pk %in% taken_names) {
        index = index + 1
        pk = paste0("..row_id.", index)
      }

      super$initialize(list(
        b1 = b1, b2 = b2,
        # columns that have to be read from b1, i.e. those that b2 does not provide
        colnames_b1 = setdiff(colnames_b1, colnames_b2),
        allcolnames = unique(c(taken_names, pk)),
        index_table = index_table,
        b1_index_colname = b1_index_colname,
        b2_index_colname = b2_index_colname,
        pk = pk,
        aux_hash = calculate_hash(by_b1, by_b2, type, b1_index_colname, b2_index_colname)
      ), primary_key = pk, data_formats = "data.table")
    },

    # Return the requested rows and columns of the join as a data.table.
    # Row ids outside 1..nrow are dropped; columns that do not exist are ignored.
    # `data_format` is accepted for interface compatibility but not used; the result is
    # always built from "data.table" data of b1 and b2.
    data = function(rows, cols, data_format = "data.table") {
      d = private$.data
      rows = rows[inrange(rows, 1, nrow(d$index_table))]
      indices = d$index_table[rows]

      in_b1 = !is.na(indices$rownames_b1)
      in_b2 = !is.na(indices$rownames_b2)
      b1_rows = indices$rownames_b1[in_b1]
      b2_rows = indices$rownames_b2[in_b2]

      # Position of each requested row within the data fetched from b1 / b2.
      # Unmatched rows keep NA, which makes the data.table subset below produce an all-NA row.
      b1_index = rep(NA_integer_, nrow(indices))
      b1_index[in_b1] = seq_along(b1_rows)
      b2_index = rep(NA_integer_, nrow(indices))
      b2_index[in_b2] = seq_along(b2_rows)

      data = d$b2$data(b2_rows, cols, data_format = "data.table")[b2_index]
      remainingcols = intersect(cols, d$colnames_b1)
      if (length(remainingcols)) {
        # only ask b1 for columns that b2 did not provide, so that no column is duplicated
        data = cbind(data, d$b1$data(b1_rows, remainingcols, data_format = "data.table")[b1_index])
      }
      setkeyv(data, NULL)
      if (d$pk %in% cols) {
        set(data, , d$pk, rows)
      }
      if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) {
        set(data, , d$b2_index_colname, indices$rownames_b2)
      }
      if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) {
        set(data, , d$b1_index_colname, indices$rownames_b1)
      }
      data[, intersect(cols, names(data)), with = FALSE]
    },

    # Return the first `n` rows of the join with all columns.
    head = function(n = 6L) {
      rows = first(self$rownames, n)
      self$data(rows = rows, cols = self$colnames)
    },

    # Return a named list with the distinct values of each requested column within `rows`.
    # With `na_rm = FALSE`, NA is appended to every column of a backend that does not
    # contribute to all of the requested rows.
    # NOTE(review): unlike data(), row ids outside 1..nrow are not dropped here -- confirm intended.
    distinct = function(rows, cols, na_rm = TRUE) {
      d = private$.data
      indices = d$index_table[rows]
      rownames_b1 = rownames_b2 = NULL  # data.table column names; declared for R CMD check
      b1_rows = indices[!is.na(rownames_b1), rownames_b1]
      b2_rows = indices[!is.na(rownames_b2), rownames_b2]
      d2 = d$b2$distinct(rows = b2_rows, cols = cols, na_rm = na_rm)
      if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) {
        d2[[d$b2_index_colname]] = if (na_rm) unique(b2_rows) else unique(indices$rownames_b2)
      }
      # b1 is only asked for columns not already covered by b2
      d1 = d$b1$distinct(rows = b1_rows, cols = setdiff(cols, names(d2)), na_rm = na_rm)
      if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) {
        d1[[d$b1_index_colname]] = if (na_rm) unique(b1_rows) else unique(indices$rownames_b1)
      }

      if (!na_rm && length(b1_rows) < length(rows)) {
        d1 = map(d1, function(x) if (any(is.na(x))) x else c(x, NA))
      }
      if (!na_rm && length(b2_rows) < length(rows)) {
        d2 = map(d2, function(x) if (any(is.na(x))) x else c(x, NA))
      }
      res = c(d1, d2)
      if (d$pk %in% cols) {
        res[[d$pk]] = unique(rows)
      }

      # return in the order of `cols`, dropping columns that do not exist
      res[match(cols, names(res), nomatch = 0)]
    },

    # Return a named vector with the number of missing values of each requested column
    # within `rows`. Rows without a match in b1 / b2 count as missing for all columns
    # (including the index column) of that backend.
    # NOTE(review): unlike data(), row ids outside 1..nrow are not dropped here -- confirm intended.
    missings = function(rows, cols) {
      d = private$.data
      indices = d$index_table[rows]
      rownames_b1 = rownames_b2 = NULL  # data.table column names; declared for R CMD check
      b1_rows = indices[!is.na(rownames_b1), rownames_b1]
      b2_rows = indices[!is.na(rownames_b2), rownames_b2]
      m2 = d$b2$missings(b2_rows, cols)
      if (!is.null(d$b2_index_colname) && d$b2_index_colname %in% cols) {
        m2[d$b2_index_colname] = 0L
      }
      # b1 is only asked for columns not already covered by b2
      m1 = d$b1$missings(b1_rows, setdiff(cols, names(m2)))
      if (!is.null(d$b1_index_colname) && d$b1_index_colname %in% cols) {
        m1[d$b1_index_colname] = 0L
      }
      # add the rows that have no counterpart in the respective backend
      m1 = m1 + length(rows) - length(b1_rows)
      m2 = m2 + length(rows) - length(b2_rows)
      res = c(m1, m2)
      if (d$pk %in% cols) {
        res[d$pk] = 0L
      }
      # return in the order of `cols`, dropping columns that do not exist
      res[match(cols, names(res), nomatch = 0)]
    }
  ),
  active = list(
    # row ids of the join: 1..nrow of the index table
    rownames = function() seq_len(nrow(private$.data$index_table)),
    colnames = function() private$.data$allcolnames,
    nrow = function() nrow(private$.data$index_table),
    ncol = function() length(private$.data$allcolnames)
  ),
  private = list(
    # The hash combines the hashes of both backends with a hash of the join settings.
    .calculate_hash = function() {
      d = private$.data
      calculate_hash(d$b1$hash, d$b2$hash, d$aux_hash)
    }
  )
)
134 changes: 134 additions & 0 deletions R/DataBackendMultiCbind.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@


#' @title DataBackend Column-Binding Multiple DataBackends
#'
#' @name DataBackendMultiCbind
#' @format `R6Class` object inheriting from `DataBackend`.
#'
#' @description
#' Presents a list of `DataBackend`s as a single `DataBackend` that contains the columns
#' of all of them. Rows are matched by their row ids; the row ids of the result are the
#' union of the row ids of all backends. Entries of rows that a backend does not contain
#' are `NA` in that backend's columns.
#'
#' If a column occurs in more than one backend, it is taken from the backend that comes
#' last in `bs`.
#'
#' If all backends use the same primary key column name, that name is also the primary
#' key of the result. Otherwise a new primary key column named `"..row_id"` is added
#' (a numeric suffix is appended if that name is already taken), and the primary keys of
#' the individual backends are treated as ordinary columns.
#'
#' @section Construction:
#' ```
#' DataBackendMultiCbind$new(bs)
#' ```
#'
#' * `bs`: non-empty `list` of `DataBackend`s.
#'
#' @export
DataBackendMultiCbind = R6Class("DataBackendMultiCbind", inherit = DataBackend, cloneable = FALSE,
  public = list(
    # Validate the backends, collect the column names and choose the primary key.
    initialize = function(bs) {
      assert_list(bs, min.len = 1)
      lapply(bs, assert_backend)

      # only data formats that every backend supports are supported by the result
      formats = Reduce(intersect, map(bs, "data_formats"))

      private$.colnames = unique(unlist(map(bs, "colnames")))

      # primary key: if all backends have the same pk, just use that one.
      otherpk = unique(unlist(map(bs, "primary_key")))
      if (length(otherpk) == 1) {
        pk = otherpk
      } else {
        # otherwise: introduce a new primary key that is completely different from the previous ones.
        pk = "..row_id"
        index = 0
        while (pk %in% private$.colnames) {
          index = index + 1
          pk = paste0("..row_id.", index)
        }
        private$.colnames = c(private$.colnames, pk)
      }

      # Backends are stored in reverse order: the methods below take each column from the first
      # stored backend that has it, so later entries of `bs` take precedence.
      super$initialize(list(bs = rev(bs)), pk, formats)
    },

    # Return the requested rows and columns. Row ids not present in any backend are dropped;
    # columns that do not exist are ignored.
    data = function(rows, cols, data_format = "data.table") {
      bs = private$.data$bs

      urows = unique(rows)

      datas = list()
      pks = character(length(bs))
      include_pk = logical(length(bs))
      cols_remaining = cols
      allrows = list()
      for (i in seq_along(bs)) {
        ## Not doing 'if (!length(cols_remaining)) break' because there could still be tables remaining that add rows
        pk = bs[[i]]$primary_key
        pks[[i]] = pk
        include_pk[[i]] = pk %in% cols_remaining
        if (include_pk[[i]]) {
          datas[[i]] = bs[[i]]$data(urows, cols_remaining, data_format = data_format)
          cols_remaining = setdiff(cols_remaining, colnames(datas[[i]]))
        } else {
          # The pk is always fetched since it is needed for matching rows; it is requested first,
          # and the '[-1]' below relies on the backend returning it as the first column.
          datas[[i]] = bs[[i]]$data(urows, c(pk, cols_remaining), data_format = data_format)
          cols_remaining = setdiff(cols_remaining, colnames(datas[[i]])[-1])
        }
        allrows[[i]] = datas[[i]][[pk]]
      }
      # all requested row ids that occur in at least one backend
      presentrows = unique(unlist(allrows))
      join = list(presentrows)
      # Align every backend's data to 'presentrows' (NA for absent rows) and bind the columns.
      result = do.call(cbind, pmap(list(datas, pks, include_pk), function(data, pk, include) {
        if (include) {
          result = data[join, on = pk, nomatch = NA]
          # The join fills the pk column with the joined-on ids even for rows that this backend
          # does not have; reset those to NA. set() only accepts integer row numbers as 'i',
          # hence the which().
          set(result, which(result[[pk]] %nin% data[[pk]]), pk, NA)
          result
        } else {
          data[join, -pk, on = pk, with = FALSE, nomatch = NA]
        }
      }))
      sbk = self$primary_key

      set(result, , sbk, presentrows)
      # bring the rows into the order of 'rows' (including duplicates)
      join = list(rows)
      result[join, intersect(cols, colnames(result)), with = FALSE, on = sbk, nomatch = NULL]
    },

    # Return the first `n` rows with all columns.
    head = function(n = 6L) {
      rows = head(self$rownames, n)
      self$data(rows = rows, cols = self$colnames)
    },

    # Return a named list with the distinct values of each requested column within `rows`.
    # With `na_rm = FALSE`, NA is appended to the columns of every backend that does not
    # contain all of the requested rows.
    distinct = function(rows, cols, na_rm = TRUE) {
      bs = private$.data$bs
      getpk = self$primary_key %in% cols
      reslist = list()
      remaining_cols = cols
      if (!na_rm || getpk) {
        rows = intersect(rows, self$rownames)
      }
      for (i in seq_along(bs)) {
        if (!length(remaining_cols)) break
        # only ask for columns that no earlier backend has provided, as in missings()
        reslist[[i]] = bs[[i]]$distinct(rows = rows, cols = remaining_cols, na_rm = na_rm)
        remaining_cols = setdiff(remaining_cols, names(reslist[[i]]))
        if (!na_rm && !all(rows %in% bs[[i]]$rownames)) {
          reslist[[i]] = map(reslist[[i]], function(x) if (any(is.na(x))) x else c(x, NA))
        }
      }
      result = unlist(reslist, recursive = FALSE)
      if (getpk) {
        result[[self$primary_key]] = rows
      }
      # return in the order of `cols`, dropping columns that do not exist
      result[match(cols, names(result), nomatch = 0)]
    },

    # Return a named vector with the number of missing values of each requested column within
    # `rows`. Row ids not present in any backend are ignored; rows that a backend does not
    # contain count as missing for all columns taken from that backend.
    missings = function(rows, cols) {
      rows = rows[rows %in% self$rownames]
      bs = private$.data$bs
      getpk = self$primary_key %in% cols
      reslist = list()
      remaining_cols = cols
      for (i in seq_along(bs)) {
        if (!length(remaining_cols)) break
        missingrows = sum(rows %nin% bs[[i]]$rownames)
        reslist[[i]] = bs[[i]]$missings(rows, remaining_cols) + missingrows
        remaining_cols = setdiff(remaining_cols, names(reslist[[i]]))
      }
      result = unlist(reslist)
      if (getpk) {
        # the primary key of the result is never missing
        result[[self$primary_key]] = 0L
      }
      # return in the order of `cols`, dropping columns that do not exist
      result[match(cols, names(result), nomatch = 0)]
    }
  ),
  active = list(
    # Union of the row ids of all backends, in the order of the original `bs`; computed once and cached.
    rownames = function() {
      if (is.null(private$.rownames_cache)) private$.rownames_cache = unique(unlist(rev(map(private$.data$bs, "rownames"))))
      private$.rownames_cache
    },
    colnames = function() {
      private$.colnames
    },
    nrow = function() length(self$rownames),
    ncol = function() length(self$colnames)
  ),
  private = list(
    .rownames_cache = NULL,
    .colnames = NULL,
    # Combine the hashes of the individual backends (in stored order) instead of hashing the
    # backend objects themselves, whose serialized form includes lazily filled caches.
    .calculate_hash = function() {
      do.call(calculate_hash, map(private$.data$bs, "hash"))
    }
  )
)
Loading