Skip to content

Commit f877e5e

Browse files
committed
Refactor V2StreamingReadTest: add withSQLConf helper, inline fileFilter, use assertDataEquals
1 parent 0ced3e4 commit f877e5e

File tree

2 files changed

+42
-28
lines changed

2 files changed

+42
-28
lines changed

spark/v2/src/test/java/io/delta/spark/internal/v2/V2StreamingReadTest.java

Lines changed: 24 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -22,18 +22,17 @@
2222
import java.util.ArrayList;
2323
import java.util.Arrays;
2424
import java.util.List;
25+
import java.util.stream.Collectors;
26+
import java.util.stream.LongStream;
2527
import org.apache.spark.sql.*;
2628
import org.apache.spark.sql.catalyst.expressions.Expression;
2729
import org.apache.spark.sql.catalyst.expressions.Literal$;
2830
import org.apache.spark.sql.delta.DeltaLog;
29-
import org.apache.spark.sql.delta.actions.AddFile;
3031
import org.apache.spark.sql.delta.stats.StatisticsCollection;
3132
import org.junit.jupiter.api.*;
3233
import org.junit.jupiter.api.io.TempDir;
33-
import scala.Function1;
3434
import scala.Option;
3535
import scala.collection.JavaConverters;
36-
import scala.runtime.AbstractFunction1;
3736

3837
/** Tests for V2 streaming read operations. */
3938
public class V2StreamingReadTest extends V2TestBase {
@@ -121,44 +120,41 @@ public void testStreamingReadAfterStatsRecompute(@TempDir File deltaTablePath) t
121120
String tablePath = deltaTablePath.getAbsolutePath();
122121

123122
// Write data with stats collection disabled - files will have no stats
124-
spark.conf().set("spark.databricks.delta.stats.collect", "false");
125-
try {
126-
spark
127-
.range(10)
128-
.selectExpr("id", "cast(id as string) as value")
129-
.write()
130-
.format("delta")
131-
.save(tablePath);
132-
} finally {
133-
spark.conf().set("spark.databricks.delta.stats.collect", "true");
134-
}
123+
withSQLConf(
124+
"spark.databricks.delta.stats.collect",
125+
"false",
126+
() ->
127+
spark
128+
.range(10)
129+
.selectExpr("id", "cast(id as string) as value")
130+
.write()
131+
.format("delta")
132+
.save(tablePath));
135133

136134
// Recompute statistics - this re-adds files with updated stats (dataChange=false),
137135
// creating duplicate AddFile entries in the log that must be filtered by selection vector
138136
DeltaLog deltaLog = DeltaLog.forTable(spark, tablePath);
139-
scala.collection.immutable.Seq<Expression> predicates =
137+
StatisticsCollection.recompute(
138+
spark,
139+
deltaLog,
140+
Option.empty(),
140141
JavaConverters.<Expression>asScalaBuffer(
141142
new ArrayList<>(List.of((Expression) Literal$.MODULE$.apply(true))))
142-
.toList();
143-
Function1<AddFile, Object> fileFilter =
144-
new AbstractFunction1<AddFile, Object>() {
145-
@Override
146-
public Object apply(AddFile af) {
147-
return (Object) Boolean.TRUE;
148-
}
149-
};
150-
StatisticsCollection.recompute(spark, deltaLog, Option.empty(), predicates, fileFilter);
143+
.toList(),
144+
af -> (Object) Boolean.TRUE);
151145

152146
// Stream via V2 - should see each row exactly once, not duplicated
153147
String dsv2TableRef = str("dsv2.delta.`%s`", tablePath);
154148
Dataset<Row> streamingDF = spark.readStream().table(dsv2TableRef);
155149

156150
List<Row> actualRows = processStreamingQuery(streamingDF, "test_stats_recompute");
157151

158-
assertEquals(
159-
10,
160-
actualRows.size(),
161-
"Stats recompute should not cause duplicate rows in streaming read. Got: " + actualRows);
152+
List<Row> expectedRows =
153+
LongStream.range(0, 10)
154+
.mapToObj(i -> RowFactory.create(i, String.valueOf(i)))
155+
.collect(Collectors.toList());
156+
157+
assertDataEquals(actualRows, expectedRows);
162158
}
163159

164160
/**

spark/v2/src/test/java/io/delta/spark/internal/v2/V2TestBase.java

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,24 @@ protected List<Row> processStreamingQuery(Dataset<Row> streamingDF, String query
147147
}
148148
}
149149

150+
/**
151+
* Runs the given action with a Spark SQL configuration temporarily set, then restores the
152+
* original value afterwards (similar to Scala's {@code withSQLConf}).
*
* <p>The previous state of {@code key} is captured before it is overwritten. Restoration happens
* in a {@code finally} block, so it also runs when {@code action} throws: the key is set back to
* its captured value if one was present, and unset otherwise.
*
* @param key name of the Spark SQL configuration entry to override
* @param value value the entry holds while {@code action} runs
* @param action code to execute under the overridden configuration; being a {@link Runnable}, it
*     cannot throw checked exceptions
153+
*/
154+
protected void withSQLConf(String key, String value, Runnable action) {
155+
scala.Option<String> original = spark.conf().getOption(key); // capture prior state before overriding
156+
spark.conf().set(key, value);
157+
try {
158+
action.run();
159+
} finally {
160+
if (original.isDefined()) {
161+
spark.conf().set(key, original.get()); // key had a value before: put it back
162+
} else {
163+
spark.conf().unset(key); // key had no value before: remove the override entirely
164+
}
165+
}
166+
}
167+
150168
/**
151169
* Asserts that rows equal the expected rows (order-independent).
152170
*

0 commit comments

Comments
 (0)