|
6 | 6 | - Python binding tests |
7 | 7 | - Functional SQL tests |
8 | 8 | - Constraint tests (expected failures) |
| 9 | +- GPU acceleration tests |
9 | 10 | """ |
10 | 11 |
|
11 | 12 | import angreal |
12 | | -from utils import run_make, ensure_extension_built |
| 13 | +import subprocess |
| 14 | +import os |
| 15 | +from utils import run_make, ensure_extension_built, get_project_root |
13 | 16 |
|
14 | 17 | test = angreal.command_group(name="test", about="Run GraphQLite tests") |
15 | 18 |
|
@@ -370,3 +373,146 @@ def test_all(verbose: bool = False) -> int: |
370 | 373 | print("All tests passed!") |
371 | 374 | print("="*50) |
372 | 375 | return 0 |
| 376 | + |
| 377 | + |
| 378 | +@test() |
| 379 | +@angreal.command( |
| 380 | + name="gpu", |
| 381 | + about="Run GPU acceleration tests", |
| 382 | + tool=angreal.ToolDescription( |
| 383 | + """ |
| 384 | +Run GPU-specific tests for the wgpu-based acceleration. |
| 385 | +
|
| 386 | +## What this tests |
| 387 | +1. Rust GPU crate unit tests (config, cost calculations) |
| 388 | +2. GPU extension build with GPU=1 |
| 389 | +3. GPU PageRank integration test (forces GPU execution) |
| 390 | +
|
| 391 | +## When to use |
| 392 | +- After changes to src/gpu/ Rust code |
| 393 | +- Validating GPU dispatch logic |
| 394 | +- Testing GPU algorithm implementations |
| 395 | +
|
| 396 | +## Examples |
| 397 | +``` |
| 398 | +angreal test gpu |
| 399 | +angreal test gpu --verbose |
| 400 | +``` |
| 401 | +
|
| 402 | +## Prerequisites |
| 403 | +- Rust toolchain installed |
| 404 | +- GPU-capable machine (Metal on macOS, Vulkan on Linux) |
| 405 | +- wgpu dependencies available |
| 406 | +""", |
| 407 | + risk_level="safe" |
| 408 | + ) |
| 409 | +) |
| 410 | +@angreal.argument( |
| 411 | + name="verbose", |
| 412 | + long="verbose", |
| 413 | + short="v", |
| 414 | + is_flag=True, |
| 415 | + takes_value=False, |
| 416 | + help="Show verbose output" |
| 417 | +) |
| 418 | +def test_gpu(verbose: bool = False) -> int: |
| 419 | + """Run GPU acceleration tests.""" |
| 420 | + root = get_project_root() |
| 421 | + gpu_dir = os.path.join(root, "src", "gpu") |
| 422 | + |
| 423 | + # Step 1: Run Rust unit tests |
| 424 | + print("Step 1: Running Rust GPU crate tests...") |
| 425 | + cmd = ["cargo", "test"] |
| 426 | + if verbose: |
| 427 | + cmd.append("--verbose") |
| 428 | + print(f"Running: {' '.join(cmd)} in {gpu_dir}") |
| 429 | + |
| 430 | + result = subprocess.run(cmd, cwd=gpu_dir) |
| 431 | + if result.returncode != 0: |
| 432 | + print("Rust GPU tests failed!") |
| 433 | + return result.returncode |
| 434 | + print("Rust GPU tests passed!") |
| 435 | + |
| 436 | + # Step 2: Build extension with GPU=1 |
| 437 | + print("\nStep 2: Building extension with GPU=1...") |
| 438 | + result = run_make("clean", verbose=verbose) |
| 439 | + if result != 0: |
| 440 | + print("Clean failed!") |
| 441 | + return result |
| 442 | + |
| 443 | + result = run_make("extension", verbose=verbose, GPU="1") |
| 444 | + if result != 0: |
| 445 | + print("GPU extension build failed!") |
| 446 | + return result |
| 447 | + print("GPU extension built successfully!") |
| 448 | + |
| 449 | + # Step 3: Run GPU integration test |
| 450 | + print("\nStep 3: Running GPU integration test...") |
| 451 | + test_script = ''' |
| 452 | +-- GPU PageRank Integration Test |
| 453 | +-- This test forces GPU execution by using a graph that exceeds the threshold |
| 454 | +
|
| 455 | +-- Create a moderately sized graph to trigger GPU dispatch |
| 456 | +-- With threshold at 100,000 and 20 iterations, we need ~5000 nodes+edges |
| 457 | +-- For simplicity, we'll test with a smaller graph but verify GPU init works |
| 458 | +
|
| 459 | +.load build/graphqlite.dylib |
| 460 | +
|
| 461 | +-- Create test graph |
| 462 | +SELECT cypher('CREATE (a:Page {id: "A"})'); |
| 463 | +SELECT cypher('CREATE (b:Page {id: "B"})'); |
| 464 | +SELECT cypher('CREATE (c:Page {id: "C"})'); |
| 465 | +SELECT cypher('CREATE (d:Page {id: "D"})'); |
| 466 | +SELECT cypher('MATCH (a:Page {id: "A"}), (b:Page {id: "B"}) CREATE (a)-[:LINKS]->(b)'); |
| 467 | +SELECT cypher('MATCH (a:Page {id: "A"}), (c:Page {id: "C"}) CREATE (a)-[:LINKS]->(c)'); |
| 468 | +SELECT cypher('MATCH (b:Page {id: "B"}), (c:Page {id: "C"}) CREATE (b)-[:LINKS]->(c)'); |
| 469 | +SELECT cypher('MATCH (c:Page {id: "C"}), (a:Page {id: "A"}) CREATE (c)-[:LINKS]->(a)'); |
| 470 | +SELECT cypher('MATCH (d:Page {id: "D"}), (c:Page {id: "C"}) CREATE (d)-[:LINKS]->(c)'); |
| 471 | +
|
| 472 | +-- Run PageRank and verify output |
| 473 | +SELECT cypher('RETURN pageRank()'); |
| 474 | +''' |
| 475 | + cmd = ["sqlite3", ":memory:"] |
| 476 | + if verbose: |
| 477 | + print(f"Running: {' '.join(cmd)}") |
| 478 | + |
| 479 | + result = subprocess.run( |
| 480 | + cmd, |
| 481 | + input=test_script, |
| 482 | + capture_output=True, |
| 483 | + text=True, |
| 484 | + cwd=root |
| 485 | + ) |
| 486 | + |
| 487 | + if verbose: |
| 488 | + print("STDOUT:", result.stdout) |
| 489 | + print("STDERR:", result.stderr) |
| 490 | + |
| 491 | + # Check for GPU initialization |
| 492 | + if "GPU acceleration enabled" not in result.stderr: |
| 493 | + print("WARNING: GPU acceleration not detected in output") |
| 494 | + print("This may be expected if no GPU is available") |
| 495 | + |
| 496 | + # Check for valid PageRank output |
| 497 | + if '"score"' not in result.stdout: |
| 498 | + print("ERROR: PageRank did not return expected results") |
| 499 | + print("Output:", result.stdout) |
| 500 | + return 1 |
| 501 | + |
| 502 | + # Verify ranking order (C should be first - highest PageRank) |
| 503 | + if '"node_id":3' not in result.stdout: |
| 504 | + print("WARNING: Node C (id:3) expected to have highest PageRank") |
| 505 | + |
| 506 | + print("GPU integration test passed!") |
| 507 | + |
| 508 | + # Step 4: Run C unit tests with GPU build |
| 509 | + print("\nStep 4: Running C unit tests with GPU build...") |
| 510 | + result = run_make("test-unit", verbose=verbose, GPU="1") |
| 511 | + if result != 0: |
| 512 | + print("C unit tests with GPU build failed!") |
| 513 | + return result |
| 514 | + |
| 515 | + print("\n" + "="*50) |
| 516 | + print("All GPU tests passed!") |
| 517 | + print("="*50) |
| 518 | + return 0 |
0 commit comments