17 | 17 |
18 | 18 | package org.apache.spark.sql.streaming.test |
19 | 19 |
20 | | -import org.scalatest.Tag |
| 20 | +import scala.concurrent.duration._ |
| 21 | + |
| 22 | +import org.apache.hadoop.fs.Path |
| 23 | +import org.mockito.ArgumentMatchers.{any, eq => meq} |
| 24 | +import org.mockito.Mockito._ |
| 25 | +import org.scalatest.{BeforeAndAfterEach, Tag} |
21 | 26 |
22 | 27 | import org.apache.spark.sql._ |
23 | 28 | import org.apache.spark.sql.internal.SQLConf |
24 | 29 | import org.apache.spark.sql.streaming.StreamTest |
| 30 | +import org.apache.spark.sql.streaming.Trigger._ |
| 31 | +import org.apache.spark.util.Utils |
25 | 32 |
26 | 33 | /** |
27 | 34 | * Test suite for streaming source naming and validation. |
28 | 35 | * Tests cover the naming API, validation rules, and resolution pipeline. |
29 | 36 | */ |
30 | | -class StreamingQueryEvolutionSuite extends StreamTest { |
| 37 | +class StreamingQueryEvolutionSuite extends StreamTest with BeforeAndAfterEach { |
| 38 | + import testImplicits._ |
| 39 | + |
| 40 | + private def newMetadataDir = |
| 41 | + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath |
| 42 | + |
| 43 | + override def afterEach(): Unit = { |
| 44 | + spark.streams.active.foreach(_.stop()) |
| 45 | + super.afterEach() |
| 46 | + } |
| 47 | + |
| 48 | + /** |
| 49 | + * Helper to verify that a source was created with the expected metadata path. |
| 50 | + * @param checkpointLocation the checkpoint location path |
| 51 | + * @param sourcePath the expected source path (e.g., "source1" or "0") |
| 52 | + * @param mode mockito verification mode (default: times(1)) |
| 53 | + */ |
| 54 | + private def verifySourcePath( |
| 55 | + checkpointLocation: Path, |
| 56 | + sourcePath: String, |
| 57 | + mode: org.mockito.verification.VerificationMode = times(1)): Unit = { |
| 58 | + verify(LastOptions.mockStreamSourceProvider, mode).createSource( |
| 59 | + any(), |
| 60 | + meq(s"${new Path(makeQualifiedPath( |
| 61 | + checkpointLocation.toString)).toString}/sources/$sourcePath"), |
| 62 | + meq(None), |
| 63 | + meq("org.apache.spark.sql.streaming.test"), |
| 64 | + meq(Map.empty)) |
| 65 | + } |
31 | 66 |
32 | 67 | // ==================== |
33 | 68 | // Name Validation Tests |
@@ -159,16 +194,258 @@ class StreamingQueryEvolutionSuite extends StreamTest { |
159 | 194 | assert(union.isStreaming, "Union should be streaming") |
160 | 195 | } |
161 | 196 |
| 197 | + test("without enforcement - naming sources throws error") { |
| 198 | + withSQLConf(SQLConf.ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "false") { |
| 199 | + checkError( |
| 200 | + exception = intercept[AnalysisException] { |
| 201 | + spark.readStream |
| 202 | + .format("org.apache.spark.sql.streaming.test") |
| 203 | + .name("mySource") |
| 204 | + .load() |
| 205 | + }, |
| 206 | + condition = "STREAMING_QUERY_EVOLUTION_ERROR.SOURCE_NAMING_NOT_SUPPORTED", |
| 207 | + parameters = Map("name" -> "mySource")) |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + // ======================= |
| 212 | + // Metadata Path Tests |
| 213 | + // ======================= |
| 214 | + |
| 215 | + testWithSourceEvolution("named sources - metadata path uses source name") { |
| 216 | + LastOptions.clear() |
| 217 | + |
| 218 | + val checkpointLocation = new Path(newMetadataDir) |
| 219 | + |
| 220 | + val df1 = spark.readStream |
| 221 | + .format("org.apache.spark.sql.streaming.test") |
| 222 | + .name("source1") |
| 223 | + .load() |
| 224 | + |
| 225 | + val df2 = spark.readStream |
| 226 | + .format("org.apache.spark.sql.streaming.test") |
| 227 | + .name("source2") |
| 228 | + .load() |
| 229 | + |
| 230 | + val q = df1.union(df2).writeStream |
| 231 | + .format("org.apache.spark.sql.streaming.test") |
| 232 | + .option("checkpointLocation", checkpointLocation.toString) |
| 233 | + .trigger(ProcessingTime(10.seconds)) |
| 234 | + .start() |
| 235 | + q.processAllAvailable() |
| 236 | + q.stop() |
| 237 | + |
| 238 | + verifySourcePath(checkpointLocation, "source1") |
| 239 | + verifySourcePath(checkpointLocation, "source2") |
| 240 | + } |
| 241 | + |
| 242 | + test("unnamed sources use positional IDs for metadata path") { |
| 243 | + withSQLConf(SQLConf.ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "false") { |
| 244 | + LastOptions.clear() |
| 245 | + |
| 246 | + val checkpointLocation = new Path(newMetadataDir) |
| 247 | + |
| 248 | + val df1 = spark.readStream |
| 249 | + .format("org.apache.spark.sql.streaming.test") |
| 250 | + .load() |
| 251 | + |
| 252 | + val df2 = spark.readStream |
| 253 | + .format("org.apache.spark.sql.streaming.test") |
| 254 | + .load() |
| 255 | + |
| 256 | + val q = df1.union(df2).writeStream |
| 257 | + .format("org.apache.spark.sql.streaming.test") |
| 258 | + .option("checkpointLocation", checkpointLocation.toString) |
| 259 | + .trigger(ProcessingTime(10.seconds)) |
| 260 | + .start() |
| 261 | + q.processAllAvailable() |
| 262 | + q.stop() |
| 263 | + |
| 264 | + // Without names, sources fall back to positional IDs (Unassigned -> 0, 1, ...)
| 265 | + verifySourcePath(checkpointLocation, "0") |
| 266 | + verifySourcePath(checkpointLocation, "1") |
| 267 | + } |
| 268 | + } |
| 269 | + |
| 270 | + // ======================== |
| 271 | + // Source Evolution Tests |
| 272 | + // ======================== |
| 273 | + |
| 274 | + testWithSourceEvolution("source evolution - reorder sources with named sources") { |
| 275 | + LastOptions.clear() |
| 276 | + |
| 277 | + val checkpointLocation = new Path(newMetadataDir) |
| 278 | + |
| 279 | + // First query: source1 then source2 |
| 280 | + val df1a = spark.readStream |
| 281 | + .format("org.apache.spark.sql.streaming.test") |
| 282 | + .name("source1") |
| 283 | + .load() |
| 284 | + |
| 285 | + val df2a = spark.readStream |
| 286 | + .format("org.apache.spark.sql.streaming.test") |
| 287 | + .name("source2") |
| 288 | + .load() |
| 289 | + |
| 290 | + val q1 = df1a.union(df2a).writeStream |
| 291 | + .format("org.apache.spark.sql.streaming.test") |
| 292 | + .option("checkpointLocation", checkpointLocation.toString) |
| 293 | + .trigger(ProcessingTime(10.seconds)) |
| 294 | + .start() |
| 295 | + q1.processAllAvailable() |
| 296 | + q1.stop() |
| 297 | + |
| 298 | + LastOptions.clear() |
| 299 | + |
| 300 | + // Second query: source2 then source1 (reordered); restart from the same checkpoint should still work
| 301 | + val df1b = spark.readStream |
| 302 | + .format("org.apache.spark.sql.streaming.test") |
| 303 | + .name("source1") |
| 304 | + .load() |
| 305 | + |
| 306 | + val df2b = spark.readStream |
| 307 | + .format("org.apache.spark.sql.streaming.test") |
| 308 | + .name("source2") |
| 309 | + .load() |
| 310 | + |
| 311 | + val q2 = df2b.union(df1b).writeStream // Note: reversed order |
| 312 | + .format("org.apache.spark.sql.streaming.test") |
| 313 | + .option("checkpointLocation", checkpointLocation.toString) |
| 314 | + .trigger(ProcessingTime(10.seconds)) |
| 315 | + .start() |
| 316 | + q2.processAllAvailable() |
| 317 | + q2.stop() |
| 318 | + |
| 319 | + // Both sources should still use their named paths |
| 320 | + verifySourcePath(checkpointLocation, "source1", atLeastOnce()) |
| 321 | + verifySourcePath(checkpointLocation, "source2", atLeastOnce()) |
| 322 | + } |
| 323 | + |
| 324 | + testWithSourceEvolution("source evolution - add new source with named sources") { |
| 325 | + LastOptions.clear() |
| 326 | + |
| 327 | + val checkpointLocation = new Path(newMetadataDir) |
| 328 | + |
| 329 | + // First query: only source1 |
| 330 | + val df1 = spark.readStream |
| 331 | + .format("org.apache.spark.sql.streaming.test") |
| 332 | + .name("source1") |
| 333 | + .load() |
| 334 | + |
| 335 | + val q1 = df1.writeStream |
| 336 | + .format("org.apache.spark.sql.streaming.test") |
| 337 | + .option("checkpointLocation", checkpointLocation.toString) |
| 338 | + .trigger(ProcessingTime(10.seconds)) |
| 339 | + .start() |
| 340 | + q1.processAllAvailable() |
| 341 | + q1.stop() |
| 342 | + |
| 343 | + LastOptions.clear() |
| 344 | + |
| 345 | + // Second query: add source2 |
| 346 | + val df1b = spark.readStream |
| 347 | + .format("org.apache.spark.sql.streaming.test") |
| 348 | + .name("source1") |
| 349 | + .load() |
| 350 | + |
| 351 | + val df2 = spark.readStream |
| 352 | + .format("org.apache.spark.sql.streaming.test") |
| 353 | + .name("source2") |
| 354 | + .load() |
| 355 | + |
| 356 | + val q2 = df1b.union(df2).writeStream |
| 357 | + .format("org.apache.spark.sql.streaming.test") |
| 358 | + .option("checkpointLocation", checkpointLocation.toString) |
| 359 | + .trigger(ProcessingTime(10.seconds)) |
| 360 | + .start() |
| 361 | + q2.processAllAvailable() |
| 362 | + q2.stop() |
| 363 | + |
| 364 | + // Both the existing source1 and the newly added source2 should be created
| 365 | + verifySourcePath(checkpointLocation, "source1", atLeastOnce()) |
| 366 | + verifySourcePath(checkpointLocation, "source2") |
| 367 | + } |
| 368 | + |
| 369 | + testWithSourceEvolution("named sources enforcement uses V2 offset log format") { |
| 370 | + LastOptions.clear() |
| 371 | + |
| 372 | + val checkpointLocation = new Path(newMetadataDir) |
| 373 | + |
| 374 | + val df1 = spark.readStream |
| 375 | + .format("org.apache.spark.sql.streaming.test") |
| 376 | + .name("source1") |
| 377 | + .load() |
| 378 | + |
| 379 | + val df2 = spark.readStream |
| 380 | + .format("org.apache.spark.sql.streaming.test") |
| 381 | + .name("source2") |
| 382 | + .load() |
| 383 | + |
| 384 | + val q = df1.union(df2).writeStream |
| 385 | + .format("org.apache.spark.sql.streaming.test") |
| 386 | + .option("checkpointLocation", checkpointLocation.toString) |
| 387 | + .trigger(ProcessingTime(10.seconds)) |
| 388 | + .start() |
| 389 | + q.processAllAvailable() |
| 390 | + q.stop() |
| 391 | + |
| 392 | + import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeqLog} |
| 393 | + val offsetLog = new OffsetSeqLog(spark, |
| 394 | + makeQualifiedPath(checkpointLocation.toString).toString + "/offsets") |
| 395 | + val offsetSeq = offsetLog.get(0) |
| 396 | + assert(offsetSeq.isDefined, "Offset log should have batch 0") |
| 397 | + assert(offsetSeq.get.isInstanceOf[OffsetMap], |
| 398 | + s"Expected OffsetMap but got ${offsetSeq.get.getClass.getSimpleName}") |
| 399 | + } |
| 400 | + |
| 401 | + testWithSourceEvolution("names preserved through union operations") { |
| 402 | + LastOptions.clear() |
| 403 | + |
| 404 | + val checkpointLocation = new Path(newMetadataDir) |
| 405 | + |
| 406 | + val df1 = spark.readStream |
| 407 | + .format("org.apache.spark.sql.streaming.test") |
| 408 | + .name("alpha") |
| 409 | + .load() |
| 410 | + |
| 411 | + val df2 = spark.readStream |
| 412 | + .format("org.apache.spark.sql.streaming.test") |
| 413 | + .name("beta") |
| 414 | + .load() |
| 415 | + |
| 416 | + val df3 = spark.readStream |
| 417 | + .format("org.apache.spark.sql.streaming.test") |
| 418 | + .name("gamma") |
| 419 | + .load() |
| 420 | + |
| 421 | + // Complex union: (alpha union beta) union gamma |
| 422 | + val q = df1.union(df2).union(df3).writeStream |
| 423 | + .format("org.apache.spark.sql.streaming.test") |
| 424 | + .option("checkpointLocation", checkpointLocation.toString) |
| 425 | + .trigger(ProcessingTime(10.seconds)) |
| 426 | + .start() |
| 427 | + q.processAllAvailable() |
| 428 | + q.stop() |
| 429 | + |
| 430 | + // All three sources should use their named paths |
| 431 | + verifySourcePath(checkpointLocation, "alpha") |
| 432 | + verifySourcePath(checkpointLocation, "beta") |
| 433 | + verifySourcePath(checkpointLocation, "gamma") |
| 434 | + } |
| 435 | + |
162 | 436 | // ============== |
163 | 437 | // Helper Methods |
164 | 438 | // ============== |
165 | 439 |
166 | 440 | /** |
167 | 441 | * Helper method to run tests with source evolution enabled. |
| 442 | + * Sets the offset log format to V2 (OffsetMap), since named sources require it.
168 | 443 | */ |
169 | 444 | def testWithSourceEvolution(testName: String, testTags: Tag*)(testBody: => Any): Unit = { |
170 | 445 | test(testName, testTags: _*) { |
171 | | - withSQLConf(SQLConf.ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "true") { |
| 446 | + withSQLConf( |
| 447 | + SQLConf.ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "true", |
| 448 | + SQLConf.STREAMING_OFFSET_LOG_FORMAT_VERSION.key -> "2") { |
172 | 449 | testBody |
173 | 450 | } |
174 | 451 | } |