Skip to content

Commit a1c37b7

Browse files
ion-elgrecortyler
authored andcommitted
fix: add nullability check in deltachecker
1 parent 1083c8c commit a1c37b7

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed

crates/core/src/delta_datafusion/mod.rs

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,11 +1150,12 @@ pub(crate) async fn execute_plan_to_batch(
11501150
Ok(concat_batches(&plan.schema(), data.iter())?)
11511151
}
11521152

1153-
/// Responsible for checking batches of data conform to table's invariants.
1154-
#[derive(Clone)]
1153+
/// Responsible for checking batches of data conform to table's invariants, constraints and nullability.
1154+
#[derive(Clone, Default)]
11551155
pub struct DeltaDataChecker {
11561156
constraints: Vec<Constraint>,
11571157
invariants: Vec<Invariant>,
1158+
non_nullable_columns: Vec<String>,
11581159
ctx: SessionContext,
11591160
}
11601161

@@ -1164,6 +1165,7 @@ impl DeltaDataChecker {
11641165
Self {
11651166
invariants: vec![],
11661167
constraints: vec![],
1168+
non_nullable_columns: vec![],
11671169
ctx: DeltaSessionContext::default().into(),
11681170
}
11691171
}
@@ -1173,6 +1175,7 @@ impl DeltaDataChecker {
11731175
Self {
11741176
invariants,
11751177
constraints: vec![],
1178+
non_nullable_columns: vec![],
11761179
ctx: DeltaSessionContext::default().into(),
11771180
}
11781181
}
@@ -1182,6 +1185,7 @@ impl DeltaDataChecker {
11821185
Self {
11831186
constraints,
11841187
invariants: vec![],
1188+
non_nullable_columns: vec![],
11851189
ctx: DeltaSessionContext::default().into(),
11861190
}
11871191
}
@@ -1202,9 +1206,21 @@ impl DeltaDataChecker {
12021206
pub fn new(snapshot: &DeltaTableState) -> Self {
12031207
let invariants = snapshot.schema().get_invariants().unwrap_or_default();
12041208
let constraints = snapshot.table_config().get_constraints();
1209+
let non_nullable_columns = snapshot
1210+
.schema()
1211+
.fields()
1212+
.filter_map(|f| {
1213+
if !f.is_nullable() {
1214+
Some(f.name().clone())
1215+
} else {
1216+
None
1217+
}
1218+
})
1219+
.collect_vec();
12051220
Self {
12061221
invariants,
12071222
constraints,
1223+
non_nullable_columns,
12081224
ctx: DeltaSessionContext::default().into(),
12091225
}
12101226
}
@@ -1214,10 +1230,35 @@ impl DeltaDataChecker {
12141230
/// If it does not, it will return [DeltaTableError::InvalidData] with a list
12151231
/// of values that violated each invariant.
12161232
pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> {
1233+
self.check_nullability(record_batch)?;
12171234
self.enforce_checks(record_batch, &self.invariants).await?;
12181235
self.enforce_checks(record_batch, &self.constraints).await
12191236
}
12201237

1238+
/// Return true if all the nullability checks are valid
1239+
fn check_nullability(&self, record_batch: &RecordBatch) -> Result<bool, DeltaTableError> {
1240+
let mut violations = Vec::new();
1241+
for col in self.non_nullable_columns.iter() {
1242+
if let Some(arr) = record_batch.column_by_name(col) {
1243+
if arr.null_count() > 0 {
1244+
violations.push(format!(
1245+
"Non-nullable column violation for {col}, found {} null values",
1246+
arr.null_count()
1247+
));
1248+
}
1249+
} else {
1250+
violations.push(format!(
1251+
"Non-nullable column violation for {col}, not found in batch!"
1252+
));
1253+
}
1254+
}
1255+
if !violations.is_empty() {
1256+
Err(DeltaTableError::InvalidData { violations })
1257+
} else {
1258+
Ok(true)
1259+
}
1260+
}
1261+
12211262
async fn enforce_checks<C: DataCheck>(
12221263
&self,
12231264
record_batch: &RecordBatch,
@@ -2598,4 +2639,38 @@ mod tests {
25982639

25992640
assert_eq!(actual.len(), 0);
26002641
}
2642+
2643+
#[tokio::test]
2644+
async fn test_check_nullability() -> DeltaResult<()> {
2645+
use arrow::array::StringArray;
2646+
2647+
let data_checker = DeltaDataChecker {
2648+
non_nullable_columns: vec!["zed".to_string(), "yap".to_string()],
2649+
..Default::default()
2650+
};
2651+
2652+
let arr: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
2653+
let nulls: Arc<dyn Array> = Arc::new(StringArray::new_null(1));
2654+
let batch = RecordBatch::try_from_iter(vec![("a", arr), ("zed", nulls)]).unwrap();
2655+
2656+
let result = data_checker.check_nullability(&batch);
2657+
assert!(
2658+
result.is_err(),
2659+
"The result should have errored! {result:?}"
2660+
);
2661+
2662+
let arr: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
2663+
let batch = RecordBatch::try_from_iter(vec![("zed", arr)]).unwrap();
2664+
let result = data_checker.check_nullability(&batch);
2665+
assert!(
2666+
result.is_err(),
2667+
"The result should have errored! {result:?}"
2668+
);
2669+
2670+
let arr: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
2671+
let batch = RecordBatch::try_from_iter(vec![("zed", arr.clone()), ("yap", arr)]).unwrap();
2672+
let _ = data_checker.check_nullability(&batch)?;
2673+
2674+
Ok(())
2675+
}
26012676
}

python/tests/test_merge.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from deltalake import DeltaTable, write_deltalake
9+
from deltalake.exceptions import DeltaProtocolError
910
from deltalake.table import CommitProperties
1011

1112

@@ -1080,3 +1081,42 @@ def test_cdc_merge_planning_union_2908(tmp_path):
10801081
assert last_action["operation"] == "MERGE"
10811082
assert dt.version() == 1
10821083
assert os.path.exists(cdc_path), "_change_data doesn't exist"
1084+
1085+
1086+
@pytest.mark.pandas
1087+
def test_merge_non_nullable(tmp_path):
1088+
import re
1089+
1090+
import pandas as pd
1091+
1092+
from deltalake.schema import Field, PrimitiveType, Schema
1093+
1094+
schema = Schema(
1095+
[
1096+
Field("id", PrimitiveType("integer"), nullable=False),
1097+
Field("bool", PrimitiveType("boolean"), nullable=False),
1098+
]
1099+
)
1100+
1101+
dt = DeltaTable.create(tmp_path, schema=schema)
1102+
df = pd.DataFrame(
1103+
columns=["id", "bool"],
1104+
data=[
1105+
[1, True],
1106+
[2, None],
1107+
[3, False],
1108+
],
1109+
)
1110+
1111+
with pytest.raises(
1112+
DeltaProtocolError,
1113+
match=re.escape(
1114+
'Invariant violations: ["Non-nullable column violation for bool, found 1 null values"]'
1115+
),
1116+
):
1117+
dt.merge(
1118+
source=df,
1119+
source_alias="s",
1120+
target_alias="t",
1121+
predicate="s.id = t.id",
1122+
).when_matched_update_all().when_not_matched_insert_all().execute()

0 commit comments

Comments
 (0)