Skip to content

Refactor row_callback system on type.frame #15

@dereckmezquita

Description

@dereckmezquita
# Example typed-frame generator for a "person" record.
#
# Columns are validated against `col_types`; extra columns are permitted
# (`freeze_n_cols = FALSE`), NA values are rejected (`allow_na = FALSE`) and
# any violation is reported through the "error" policy.
PersonFrame <- type.frame(
    frame = data.frame,
    col_types = list(
        id = integer,
        name = character,
        age = numeric,
        is_student = logical,
        # functions are applied individually to each element; not the whole column
        email = function(x) grepl("^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$", x)
    ),
    freeze_n_cols = FALSE,
    # Row-level rule: returns TRUE when the row is acceptable, otherwise a
    # character message describing why the row was rejected.
    row_callback = function(row) {
        if (row$age >= 40) {
            # `age` is declared numeric; `%d` makes sprintf() fail on a
            # non-whole value such as 40.5, so format with `%s` instead.
            return(sprintf("Age must be less than 40 (got %s)", row$age))
        }
        if (row$name == "Yanice") {
            return("Name cannot be 'Yanice'")
        }
        return(TRUE)
    },
    allow_na = FALSE,
    on_violation = "error"
)
#' Build a validating constructor for a typed data frame
#'
#' Returns a function that forwards its arguments to `frame`, validates the
#' result, and tags it with class "typed_frame" plus the validation settings
#' as attributes. All problems found are collected and reported together.
#'
#' @param frame Constructor called as `frame(...)` to build the underlying
#'   data frame (e.g. `data.frame`).
#' @param col_types Named list mapping column names to a type validator.
#'   Entries inheriting from "enum_generator" are called once per element;
#'   all other entries are delegated to `validate_property()`.
#' @param freeze_n_cols If `TRUE`, the number of columns must equal
#'   `length(col_types)`.
#' @param row_callback Optional function called with each one-row slice of the
#'   frame. It must return `TRUE` for a valid row; any other return value is
#'   converted to text and reported as that row's failure reason.
#' @param allow_na If `FALSE`, any NA value in the frame is a violation.
#' @param on_violation One of "error", "warning" or "silent"; passed to
#'   `handle_violation()` together with the combined message.
#' @return A function `creator(...)` returning the validated typed frame.
type.frame <- function(
    frame,
    col_types,
    freeze_n_cols = TRUE,
    row_callback = NULL,
    allow_na = TRUE,
    on_violation = c("error", "warning", "silent")
) {
    on_violation <- match.arg(on_violation)

    creator <- function(...) {
        df <- frame(...)
        errors <- character(0)

        # Missing columns: one message per column plus a summary that lists
        # only the columns that are actually absent.
        missing_cols <- setdiff(names(col_types), names(df))
        if (length(missing_cols) > 0) {
            errors <- c(
                errors,
                sprintf("Required column '%s' is missing", missing_cols),
                sprintf("Missing columns: %s", paste(missing_cols, collapse = ", "))
            )
        }

        # number of columns check
        if (freeze_n_cols && ncol(df) != length(col_types)) {
            errors <- c(errors, sprintf(
                "Number of columns must match: expected %d, got %d",
                length(col_types), ncol(df)
            ))
        }

        # NA check, column by column (vapply keeps the result type stable
        # even for a zero-column frame).
        if (!allow_na) {
            has_na <- vapply(df, function(x) any(is.na(x)), logical(1))
            if (any(has_na)) {
                errors <- c(errors, sprintf(
                    "NA values found in column(s): %s",
                    paste(names(df)[has_na], collapse = ", ")
                ))
            }
        }

        # Type-check every declared column that is present; absent columns
        # were already reported above and have no data to validate.
        for (col_name in setdiff(names(col_types), missing_cols)) {
            curr_col_data <- df[[col_name]]
            curr_col_type <- col_types[[col_name]]

            if (inherits(curr_col_type, "enum_generator")) {
                # Enum generators validate one value at a time. The handler's
                # return value carries the message out of tryCatch(); an
                # assignment inside the handler would only modify the
                # handler's own local scope and be lost.
                enum_errors <- lapply(seq_along(curr_col_data), function(i) {
                    tryCatch({
                        curr_col_type(curr_col_data[[i]])
                        NULL
                    }, error = function(e) {
                        sprintf("Row %d: %s", i, conditionMessage(e))
                    })
                })
                errors <- c(errors, unlist(enum_errors))
            } else {
                # validate_property() is defined elsewhere; from its use here
                # it is expected to return NULL on success or a message.
                error <- validate_property(col_name, curr_col_data, curr_col_type)
                if (!is.null(error)) {
                    errors <- c(errors, error)
                }
            }
        }

        # Run the row-level callback; failures and thrown errors are both
        # collected rather than aborting on the first bad row.
        if (!is.null(row_callback)) {
            row_errors <- lapply(seq_len(nrow(df)), function(i) {
                row <- df[i, , drop = FALSE]
                tryCatch({
                    result <- row_callback(row)
                    if (isTRUE(result)) {
                        NULL
                    } else {
                        # Collapse so a zero-length or multi-element result
                        # still yields exactly one message for this row.
                        sprintf(
                            "Row %d failed validation: %s",
                            i, paste(as.character(result), collapse = "; ")
                        )
                    }
                }, error = function(e) {
                    sprintf("Error processing row %d: %s", i, conditionMessage(e))
                })
            })
            errors <- c(errors, unlist(row_errors))
        }

        # Handle all collected errors
        if (length(errors) > 0) {
            error_message <- paste(
                "Validation errors:",
                paste(errors, collapse = "\n  "),
                sep = "\n  "
            )
            handle_violation(error_message, on_violation)
        }

        # Create the typed data frame
        structure(
            df,
            class = c("typed_frame", class(df)),
            col_types = col_types,
            freeze_n_cols = freeze_n_cols,
            row_callback = row_callback,
            allow_na = allow_na,
            on_violation = on_violation
        )
    }

    return(creator)
}

Metadata

Metadata

Labels

No labels
No labels

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions