Skip to content

Commit 5a3cd8c

Browse files
committed
amend def fill_null to invoke PyDataFrame's fill_null
- Implemented `fill_null` method in `dataframe.rs` to allow filling null values with a specified value for specific columns or all columns. - Added a helper function `python_value_to_scalar_value` to convert Python values to DataFusion ScalarValues, supporting various types including integers, floats, booleans, strings, and timestamps. - Updated the `count` method in `PyDataFrame` to maintain functionality.
1 parent 8b51ee9 commit 5a3cd8c

File tree

2 files changed

+64
-88
lines changed

2 files changed

+64
-88
lines changed

python/datafusion/dataframe.py

Lines changed: 1 addition & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -894,92 +894,5 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> "DataFrame":
894894
- For columns where casting fails, the original column is kept unchanged
895895
- For columns not in subset, the original column is kept unchanged
896896
"""
897-
# Get columns to process
898-
if subset is None:
899-
subset = self.schema().names
900-
else:
901-
schema_cols = self.schema().names
902-
for col in subset:
903-
if col not in schema_cols:
904-
raise ValueError(f"Column '{col}' not found in DataFrame")
905-
906-
# Build expressions for select
907-
exprs = []
908-
for col_name in self.schema().names:
909-
if col_name in subset:
910-
# Get column type
911-
col_type = self.schema().field(col_name).type
912-
913-
try:
914-
# Try casting value to column type
915-
typed_value = pa.scalar(value, type=col_type)
916-
literal_expr = f.Expr.literal(typed_value)
917-
918-
# Build coalesce expression
919-
expr = f.coalesce(f.col(col_name), literal_expr)
920-
exprs.append(expr.alias(col_name))
921-
922-
except (pa.ArrowTypeError, pa.ArrowInvalid):
923-
# If cast fails, keep original column
924-
exprs.append(f.col(col_name))
925-
else:
926-
# Keep columns not in subset unchanged
927-
exprs.append(f.col(col_name))
928-
929-
return self.select(*exprs)
930-
931-
def fill_nan(
    self, value: float | int, subset: list[str] | None = None
) -> "DataFrame":
    """Replace NaN entries in floating-point columns with ``value``.

    Args:
        value: Numeric replacement for NaN entries.
        subset: Names of the columns to process. When ``None``, every
            floating-point column of the schema is processed.

    Returns:
        A DataFrame in which NaN values of the selected columns are replaced;
        all other columns are passed through untouched.

    Raises:
        ValueError: If ``value`` is not an ``int`` or ``float``, if a name in
            ``subset`` is not a column of this DataFrame, or if a named column
            is not of a floating-point type.

    Examples:
        >>> df = df.fill_nan(0)  # Fill all NaNs with 0 in numeric columns
        >>> # Fill NaNs in specific numeric columns
        >>> df = df.fill_nan(99.9, subset=["price", "score"])
    """
    if not isinstance(value, (int, float)):
        raise ValueError("Value must be numeric (int or float)")

    if subset is not None:
        # Validate the caller-supplied names before building any expression.
        known_columns = self.schema().names
        for name in subset:
            if name not in known_columns:
                raise ValueError(f"Column '{name}' not found in DataFrame")
            if not pa.types.is_floating(self.schema().field(name).type):
                raise ValueError(f"Column '{name}' is not a numeric column")
    else:
        # No explicit selection: default to every floating-point column.
        subset = [
            fld.name for fld in self.schema() if pa.types.is_floating(fld.type)
        ]

    # Project every column in schema order; selected ones go through nanvl and
    # keep their original name, the rest are forwarded as-is.
    projection = [
        f.nanvl(f.col(name), f.lit(value)).alias(name)
        if name in subset
        else f.col(name)
        for name in self.schema().names
    ]
    return self.select(*projection)
898+
return DataFrame(self.df.fill_null(value, subset))

src/dataframe.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,25 @@ impl PyDataFrame {
797797
fn count(&self, py: Python) -> PyDataFusionResult<usize> {
798798
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
799799
}
800+
801+
/// Fill null values with a specified value for specific columns
802+
#[pyo3(signature = (value, columns=None))]
803+
fn fill_null(
804+
&self,
805+
value: PyObject,
806+
columns: Option<Vec<PyBackedStr>>,
807+
py: Python,
808+
) -> PyDataFusionResult<Self> {
809+
let scalar_value = python_value_to_scalar_value(&value, py)?;
810+
811+
let cols = match columns {
812+
Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
813+
None => Vec::new(), // Empty vector means fill null for all columns
814+
};
815+
816+
let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
817+
Ok(Self::new(df))
818+
}
800819
}
801820

802821
/// Print DataFrame
@@ -951,3 +970,47 @@ async fn collect_record_batches_to_display(
951970

952971
Ok((record_batches, has_more))
953972
}
973+
974+
/// Convert a Python value to a DataFusion ScalarValue
975+
fn python_value_to_scalar_value(value: &PyObject, py: Python) -> PyDataFusionResult<ScalarValue> {
976+
if value.is_none(py) {
977+
return Err(PyDataFusionError::Common(
978+
"Cannot use None as fill value".to_string(),
979+
));
980+
} else if let Ok(val) = value.extract::<i64>(py) {
981+
return Ok(ScalarValue::Int64(Some(val)));
982+
} else if let Ok(val) = value.extract::<f64>(py) {
983+
return Ok(ScalarValue::Float64(Some(val)));
984+
} else if let Ok(val) = value.extract::<bool>(py) {
985+
return Ok(ScalarValue::Boolean(Some(val)));
986+
} else if let Ok(val) = value.extract::<String>(py) {
987+
return Ok(ScalarValue::Utf8(Some(val)));
988+
} else if let Ok(dt) = py
989+
.import("datetime")
990+
.and_then(|m| m.getattr("datetime"))
991+
.and_then(|dt| value.is_instance(dt))
992+
{
993+
if value.is_instance_of::<pyo3::types::PyDateTime>(py) {
994+
let naive_dt = value.extract::<chrono::NaiveDateTime>(py)?;
995+
return Ok(ScalarValue::TimestampNanosecond(
996+
Some(naive_dt.timestamp_nanos()),
997+
None,
998+
));
999+
} else {
1000+
return Err(PyDataFusionError::Common(
1001+
"Unsupported datetime type".to_string(),
1002+
));
1003+
}
1004+
}
1005+
1006+
// Try to convert to string as fallback
1007+
match value.str(py) {
1008+
Ok(py_str) => {
1009+
let s = py_str.to_string()?;
1010+
Ok(ScalarValue::Utf8(Some(s)))
1011+
}
1012+
Err(_) => Err(PyDataFusionError::Common(
1013+
"Unsupported Python type for fill_null".to_string(),
1014+
)),
1015+
}
1016+
}

0 commit comments

Comments
 (0)