|
| 1 | +#' @title Customizable Information Printer |
| 2 | +#' |
| 3 | +#' @usage NULL |
| 4 | +#' @name mlr_pipeops_info |
| 5 | +#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOp`] |
| 6 | +#' |
| 7 | +#' @description |
| 8 | +#' `PipeOpInfo` prints its input to the console or a logger in a customizable way. |
| 9 | +#' Users can define how specific object classes should be displayed using custom printer functions. |
| 10 | +#' |
| 11 | +#' @section Construction: |
| 12 | +#' ``` |
| 13 | +#' PipeOpInfo$new(id = "info", collect_multiplicity = FALSE, log_target = "lgr::mlr3/mlr3pipelines::info") |
| 14 | +#' ``` |
| 15 | +#' * `id` :: `character(1)`\cr |
| 16 | +#' Identifier of resulting object, default "info" |
| 17 | +#' * `printer` :: `list` \cr |
| 18 | +#' Optional mapping from object classes to printer functions. Custom functions override default printer-functions. |
| 19 | +#' * `collect_multiplicity` :: `logical(1)`\cr |
| 20 | +#' If `TRUE`, the input is a [`Multiplicity`] collecting channel. [`Multiplicity`] input/output is accepted and the members are aggregated. |
| 21 | +#' * `log_target` :: `character(1)`\cr |
| 22 | +#' Specifies how the input object is printed to the console. By default it is |
| 23 | +#' directed to a logger, whose address can be customized using the form |
| 24 | +#' `<output>::<argument1>::<argument2>`. Otherwise it can be printed |
| 25 | +#' as "message", "warning" or "cat". When set to "none", no customized |
| 26 | +#' information about the object will be printed. |
| 27 | +#' |
| 28 | +#' @section Input and Output Channels: |
| 29 | +#' `PipeOpInfo` has one input channel called "input", it can take any type of input (`*`). |
| 30 | +#' `PipeOpInfo` has one output channel called "output", it can take any type of output (`*`). |
| 31 | +#' |
| 32 | +#' @section State: |
| 33 | +#' The `$state` is left empty (`list()`). |
| 34 | +#' |
| 35 | +#' @section Internals: |
| 36 | +#' `PipeOpInfo` forwards its input unchanged, but prints information about it |
| 37 | +#' depending on the `printer` and `log_target` settings. |
| 38 | +#' |
| 39 | +#' @section Fields: |
| 40 | +#' Fields inherited from `PipeOp`, as well as: |
| 41 | +#' * `printer` :: `list`\cr |
| 42 | +#' Mapping of object classes to printer functions. Includes printer-specifications for `Task`, `Prediction`, `NULL`. Otherwise object is printed as is. |
| 43 | +#' * `log_target` :: `character(1)` \cr |
| 44 | +#' Specifies current output target. |
| 45 | +#' |
| 46 | +#' @section Methods: |
| 47 | +#' Only methods inherited from [`PipeOp`]. |
| 48 | +#' |
| 49 | +#' @examples |
| 50 | +#' library("mlr3") |
| 51 | +#' |
| 52 | +#' poinfo = po("info") |
| 53 | +#' poinfo$train(list(tsk("mtcars"))) |
| 54 | +#' poinfo$predict(list(tsk("mtcars"))) |
| 55 | +#' |
| 56 | +#' # Specify customized console output for Task-objects |
| 57 | +#' poinfo = po("info", log_target = "cat", |
| 58 | +#' printer = list(Task = function(x) list(head_data = head(x$data()), nrow = nrow(x$data()))) |
| 59 | +#' ) |
| 60 | +#' |
| 61 | +#' poinfo$train(list(tsk("iris"))) |
| 62 | +#' poinfo$predict(list(tsk("iris"))) |
| 63 | +#' |
| 64 | +#' @family PipeOps |
| 65 | +#' @template seealso_pipeopslist |
| 66 | +#' @include PipeOp.R |
| 67 | +#' @export |
| 68 | +#' |
| 69 | +#' |
| 70 | + |
| 71 | +PipeOpInfo = R6Class("PipeOpInfo", |
| 72 | + inherit = PipeOp, |
| 73 | + public = list( |
| 74 | + initialize = function(id = "info", printer = NULL, collect_multiplicity = FALSE, log_target = "lgr::mlr3/mlr3pipelines::info", param_vals = list()) { |
| 75 | + assertString(log_target, pattern = "^(cat|none|warning|message|lgr::[^:]+::[^:]+)$") |
| 76 | + inouttype = "*" |
| 77 | + if (collect_multiplicity) { |
| 78 | + inouttype = sprintf("[%s]", inouttype) |
| 79 | + } |
| 80 | + super$initialize(id, param_vals = param_vals, |
| 81 | + input = data.table(name = "input", train = inouttype, predict = inouttype), |
| 82 | + output = data.table(name = "output", train = inouttype, predict = inouttype) |
| 83 | + #tag = "debug" |
| 84 | + ) |
| 85 | + original_printer = list( |
| 86 | + Task = crate(function(x) { |
| 87 | + row_preview = head(x$row_ids, 10L) |
| 88 | + col_preview = head(c(x$target_names, x$feature_names), 10L) |
| 89 | + data_preview = x$data(rows = row_preview, cols = col_preview) |
| 90 | + list( |
| 91 | + task = x, |
| 92 | + data_preview = data_preview |
| 93 | + ) |
| 94 | + }), |
| 95 | + Prediction = crate(function(x) { |
| 96 | + tryCatch(list(prediction = x, score = x$score()), error = function(e) {list(prediction = x)}) |
| 97 | + }), |
| 98 | + `NULL` = crate(function(x) "NULL"), |
| 99 | + default = crate(function(x) x) |
| 100 | + ) |
| 101 | + private$.printer = insert_named(original_printer, printer) |
| 102 | + private$.log_target = log_target |
| 103 | + } |
| 104 | + ), |
| 105 | + active = list( |
| 106 | + printer = function(rhs) { |
| 107 | + if (!missing(rhs)) stop("printer is read only.") |
| 108 | + private$.printer |
| 109 | + }, |
| 110 | + log_target = function(rhs) { |
| 111 | + if (!missing(rhs)) stop("log_target is read only.") |
| 112 | + private$.log_target |
| 113 | + } |
| 114 | + ), |
| 115 | + private = list( |
| 116 | + .printer = NULL, |
| 117 | + .log_target = NULL, |
| 118 | + .output = function(inputs, stage) { |
| 119 | + input_class = class(inputs[[1]]) |
| 120 | + leftmost_class = |
| 121 | + if (any(input_class %in% names(private$.printer))) { |
| 122 | + input_class[input_class %in% names(private$.printer)][[1]] |
| 123 | + } else { |
| 124 | + "default" |
| 125 | + } |
| 126 | + if (!("default" %in% names(private$.printer))) { |
| 127 | + stop("Object-class was not found and no default printer is available.") |
| 128 | + } |
| 129 | + specific_printer = private$.printer[[leftmost_class]] |
| 130 | + log_target_split = strsplit(private$.log_target, "::")[[1]] |
| 131 | + stage_string = sprintf("Object passing through PipeOp %s - %s", self$id, stage) |
| 132 | + print_string = utils::capture.output({ |
| 133 | + cat(stage_string, "\n\n") |
| 134 | + specific_printer(inputs[[1]]) |
| 135 | + }) |
| 136 | + message_text = paste(print_string, collapse = "\n") |
| 137 | + if (log_target_split[[1]] == "lgr") { |
| 138 | + logger = lgr::get_logger(log_target_split[[2]]) |
| 139 | + log_level = log_target_split[[3]] |
| 140 | + logger$log(log_level, msg = message_text) |
| 141 | + } else if (private$.log_target == "cat") { |
| 142 | + cat(message_text) |
| 143 | + } else if (private$.log_target == "message") { |
| 144 | + message(message_text) |
| 145 | + } else if (private$.log_target == "warning") { |
| 146 | + warning(message_text) |
| 147 | + } else if (private$.log_target == "none") { |
| 148 | + } else { |
| 149 | + stopf("Invalid log_target '%s'.", private$.log_target) |
| 150 | + } |
| 151 | + }, |
| 152 | + .train = function(inputs, stage = "Training") { |
| 153 | + self$state = list() |
| 154 | + private$.output(inputs, stage) |
| 155 | + inputs |
| 156 | + }, |
| 157 | + .predict = function(inputs, stage = "Prediction") { |
| 158 | + private$.output(inputs, stage) |
| 159 | + inputs |
| 160 | + }, |
| 161 | + .additional_phash_input = function() { |
| 162 | + list(printer = self$printer, log_target = self$log_target) |
| 163 | + } |
| 164 | + ) |
| 165 | +) |
| 166 | + |
| 167 | +mlr_pipeops$add("info", PipeOpInfo) |
0 commit comments