diff --git a/docs/source/user-guide/common-operations/expressions.rst b/docs/source/user-guide/common-operations/expressions.rst index e35234c32..6014c9d2e 100644 --- a/docs/source/user-guide/common-operations/expressions.rst +++ b/docs/source/user-guide/common-operations/expressions.rst @@ -60,6 +60,43 @@ examples for the and, or, and not operations. heavy_red_units = (col("color") == lit("red")) & (col("weight") > lit(42)) not_red_units = ~(col("color") == lit("red")) +Arrays +------ + +For columns that contain arrays of values, you can access individual elements of the array by index +using bracket indexing. This is similar to callling the function +:py:func:`datafusion.functions.array_element`, except that array indexing using brackets is 0 based, +similar to Python arrays and ``array_element`` is 1 based indexing to be compatible with other SQL +approaches. + +.. ipython:: python + + from datafusion import SessionContext, col + + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5, 6]]}) + df.select(col("a")[0].alias("a0")) + + +.. warning:: + + Indexing an element of an array via ``[]`` starts at index 0 whereas + :py:func:`~datafusion.functions.array_element` starts at index 1. + +Structs +------- + +Columns that contain struct elements can be accessed using the bracket notation as if they were +Python dictionary style objects. This expects a string key as the parameter passed. + +.. ipython:: python + + ctx = SessionContext() + data = {"a": [{"size": 15, "color": "green"}, {"size": 10, "color": "blue"}]} + df = ctx.from_pydict(data) + df.select(col("a")["size"].alias("a_size")) + + Functions --------- diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 0e7d82e29..46b8fa1bd 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -30,6 +30,7 @@ import pandas as pd import polars as pl import pathlib + from typing import Callable from datafusion._internal import DataFrame as DataFrameInternal from datafusion.expr import Expr @@ -72,6 +73,9 @@ def __repr__(self) -> str: """ return self.df.__repr__() + def _repr_html_(self) -> str: + return self.df._repr_html_() + def describe(self) -> DataFrame: """Return the statistics for this DataFrame. @@ -539,3 +543,25 @@ def __arrow_c_stream__(self, requested_schema: pa.Schema) -> Any: Arrow PyCapsule object. """ return self.df.__arrow_c_stream__(requested_schema) + + def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame: + """Apply a function to the current DataFrame which returns another DataFrame. + + This is useful for chaining together multiple functions. For example:: + + def add_3(df: DataFrame) -> DataFrame: + return df.with_column("modified", lit(3)) + + def within_limit(df: DataFrame, limit: int) -> DataFrame: + return df.filter(col("a") < lit(limit)).distinct() + + df = df.transform(modify_df).transform(within_limit, 4) + + Args: + func: A callable function that takes a DataFrame as it's first argument + args: Zero or more arguments to pass to `func` + + Returns: + DataFrame: After applying func to the original dataframe. + """ + return func(self, *args) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 742f8e43d..7bea0289b 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -22,7 +22,11 @@ from __future__ import annotations -from ._internal import expr as expr_internal, LogicalPlan +from ._internal import ( + expr as expr_internal, + LogicalPlan, + functions as functions_internal, +) from datafusion.common import NullTreatment, RexType, DataTypeMap from typing import Any, Optional import pyarrow as pa @@ -257,8 +261,17 @@ def __invert__(self) -> Expr: """Binary not (~).""" return Expr(self.expr.__invert__()) - def __getitem__(self, key: str) -> Expr: - """For struct data types, return the field indicated by ``key``.""" + def __getitem__(self, key: str | int) -> Expr: + """Retrieve sub-object. + + If ``key`` is a string, returns the subfield of the struct. + If ``key`` is an integer, retrieves the element in the array. Note that the + element index begins at ``0``, unlike `array_element` which begines at ``1``. + """ + if isinstance(key, int): + return Expr( + functions_internal.array_element(self.expr, Expr.literal(key + 1).expr) + ) return Expr(self.expr.__getitem__(key)) def __eq__(self, rhs: Any) -> Expr: diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 120fed819..4c701b24d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1023,7 +1023,7 @@ def array(*args: Expr) -> Expr: This is an alias for :py:func:`make_array`. """ - return make_array(args) + return make_array(*args) def range(start: Expr, stop: Expr, step: Expr) -> Expr: diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index c2a5f22ba..90954d09a 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import os +from typing import Any import pyarrow as pa from pyarrow.csv import write_csv @@ -970,3 +971,34 @@ def test_dataframe_export(df) -> None: except Exception: failed_convert = True assert failed_convert + + +def test_dataframe_transform(df): + def add_string_col(df_internal) -> DataFrame: + return df_internal.with_column("string_col", literal("string data")) + + def add_with_parameter(df_internal, value: Any) -> DataFrame: + return df_internal.with_column("new_col", literal(value)) + + df = df.transform(add_string_col).transform(add_with_parameter, 3) + + result = df.to_pydict() + + assert result["a"] == [1, 2, 3] + assert result["string_col"] == ["string data" for _i in range(0, 3)] + assert result["new_col"] == [3 for _i in range(0, 3)] + + +def test_dataframe_repr_html(df) -> None: + output = df._repr_html_() + + ref_html = """ + + + + +
abc
148
255
368
+ """ + + # Ignore whitespace just to make this test look cleaner + assert output.replace(" ", "") == ref_html.replace(" ", "") diff --git a/python/datafusion/tests/test_expr.py b/python/datafusion/tests/test_expr.py index 9071108cb..056d2ea03 100644 --- a/python/datafusion/tests/test_expr.py +++ b/python/datafusion/tests/test_expr.py @@ -169,3 +169,26 @@ def traverse_logical_plan(plan): == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' ) assert not variant.negated() + + +def test_expr_getitem() -> None: + ctx = SessionContext() + data = { + "array_values": [[1, 2, 3], [4, 5], [6], []], + "struct_values": [ + {"name": "Alice", "age": 15}, + {"name": "Bob", "age": 14}, + {"name": "Charlie", "age": 13}, + {"name": None, "age": 12}, + ], + } + df = ctx.from_pydict(data, name="table1") + + names = df.select(col("struct_values")["name"].alias("name")).collect() + names = [r.as_py() for rs in names for r in rs["name"]] + + array_values = df.select(col("array_values")[1].alias("value")).collect() + array_values = [r.as_py() for rs in array_values for r in rs["value"]] + + assert names == ["Alice", "Bob", "Charlie", None] + assert array_values == [2, 5, None, None] diff --git a/src/dataframe.rs b/src/dataframe.rs index d7abab400..3fb8b2292 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -23,6 +23,7 @@ use arrow::compute::can_cast_types; use arrow::error::ArrowError; use arrow::ffi::FFI_ArrowSchema; use arrow::ffi_stream::FFI_ArrowArrayStream; +use arrow::util::display::{ArrayFormatter, FormatOptions}; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::arrow::util::pretty; @@ -95,6 +96,51 @@ impl PyDataFrame { } } + fn _repr_html_(&self, py: Python) -> PyResult { + let mut html_str = "\n".to_string(); + + let df = self.df.as_ref().clone().limit(0, Some(10))?; + let batches = wait_for_future(py, df.collect())?; + + if batches.is_empty() { + html_str.push_str("
\n"); + return Ok(html_str); + } + + let schema = batches[0].schema(); + + let mut header = Vec::new(); + for field in schema.fields() { + header.push(format!("{}", field.name())); + } + let header_str = header.join(""); + html_str.push_str(&format!("{}\n", header_str)); + + for batch in batches { + let formatters = batch + .columns() + .iter() + .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) + .map(|c| { + c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))) + }) + .collect::, _>>()?; + + for row in 0..batch.num_rows() { + let mut cells = Vec::new(); + for formatter in &formatters { + cells.push(format!("{}", formatter.value(row))); + } + let row_str = cells.join(""); + html_str.push_str(&format!("{}\n", row_str)); + } + } + + html_str.push_str("\n"); + + Ok(html_str) + } + /// Calculate summary statistics for a DataFrame fn describe(&self, py: Python) -> PyResult { let df = self.df.as_ref().clone();