Skip to content

Commit b04dd7f

Browse files
committed
Update test_evolved.py
1 parent 6407304 commit b04dd7f

File tree

1 file changed

+216
-28
lines changed

1 file changed

+216
-28
lines changed

examples/mlx_spda_optimization/test_evolved.py

Lines changed: 216 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -313,43 +313,233 @@ def run_block_diagonal_tests(evolved_fn):
313313
print("Format: Test | Shape | Blocks | Sparsity | Evolved | SPDA | Speedup | Status")
314314
print("-" * 80)
315315

316-
# Block-diagonal test configurations
316+
# Block-diagonal test configurations - comprehensive coverage
317317
block_configs = [
318+
# ===== BASIC SPARSITY PROGRESSION =====
318319
{
319-
"name": "packed_2x256_sparse50",
320+
"name": "dense_2x256_sparse50",
320321
"B": 1, "H": 8, "L": 512, "D": 64,
321-
"block_sizes": [256, 256], # 50% sparse
322-
"expected_speedup": 1.2
322+
"block_sizes": [256, 256] # 50% sparse - baseline
323323
},
324324
{
325-
"name": "packed_4x128_sparse75",
325+
"name": "medium_4x128_sparse75",
326326
"B": 1, "H": 16, "L": 512, "D": 64,
327-
"block_sizes": [128, 128, 128, 128], # 75% sparse
328-
"expected_speedup": 1.5
327+
"block_sizes": [128, 128, 128, 128] # 75% sparse
329328
},
330329
{
331-
"name": "packed_8x128_sparse87",
330+
"name": "sparse_8x64_sparse87",
331+
"B": 1, "H": 16, "L": 512, "D": 64,
332+
"block_sizes": [64] * 8 # 87.5% sparse
333+
},
334+
{
335+
"name": "very_sparse_16x32_sparse93",
336+
"B": 1, "H": 16, "L": 512, "D": 64,
337+
"block_sizes": [32] * 16 # 93.75% sparse
338+
},
339+
{
340+
"name": "extreme_sparse_32x16_sparse96",
341+
"B": 1, "H": 16, "L": 512, "D": 64,
342+
"block_sizes": [16] * 32 # 96.875% sparse
343+
},
344+
345+
# ===== DIFFERENT SEQUENCE LENGTHS =====
346+
{
347+
"name": "small_seq_4x32_sparse75",
348+
"B": 1, "H": 8, "L": 128, "D": 64,
349+
"block_sizes": [32, 32, 32, 32] # Small sequences
350+
},
351+
{
352+
"name": "medium_seq_8x64_sparse87",
353+
"B": 1, "H": 16, "L": 512, "D": 64,
354+
"block_sizes": [64] * 8 # Medium sequences
355+
},
356+
{
357+
"name": "large_seq_8x128_sparse87",
358+
"B": 1, "H": 16, "L": 1024, "D": 64,
359+
"block_sizes": [128] * 8 # Large sequences
360+
},
361+
{
362+
"name": "huge_seq_16x128_sparse93",
363+
"B": 1, "H": 32, "L": 2048, "D": 64,
364+
"block_sizes": [128] * 16 # Very large sequences
365+
},
366+
{
367+
"name": "giant_seq_32x64_sparse96",
368+
"B": 1, "H": 32, "L": 2048, "D": 64,
369+
"block_sizes": [64] * 32 # Extreme sequences
370+
},
371+
372+
# ===== DIFFERENT HEAD DIMENSIONS =====
373+
{
374+
"name": "head64_8x64_sparse87",
375+
"B": 1, "H": 16, "L": 512, "D": 64,
376+
"block_sizes": [64] * 8 # Standard head dim
377+
},
378+
{
379+
"name": "head80_8x64_sparse87",
380+
"B": 1, "H": 16, "L": 512, "D": 80,
381+
"block_sizes": [64] * 8 # Non-power-of-two head dim (D=80)
382+
},
383+
{
384+
"name": "head128_8x64_sparse87",
385+
"B": 1, "H": 16, "L": 512, "D": 128,
386+
"block_sizes": [64] * 8 # Large head dim
387+
},
388+
{
389+
"name": "head32_8x64_sparse87",
390+
"B": 1, "H": 16, "L": 512, "D": 32,
391+
"block_sizes": [64] * 8 # Small head dim
392+
},
393+
394+
# ===== MIXED BLOCK SIZES =====
395+
{
396+
"name": "mixed_sizes_pyramid",
397+
"B": 1, "H": 16, "L": 1024, "D": 64,
398+
"block_sizes": [512, 256, 128, 64, 32, 16, 8, 8] # Pyramid pattern
399+
},
400+
{
401+
"name": "mixed_sizes_alternating",
332402
"B": 1, "H": 16, "L": 1024, "D": 64,
333-
"block_sizes": [128] * 8, # 87.5% sparse
334-
"expected_speedup": 2.0
403+
"block_sizes": [128, 64, 128, 64, 128, 64, 128, 64, 128, 64] # Alternating; NOTE(review): sizes sum to 960, not L=1024 - confirm how the last 64 positions are masked
404+
},
405+
{
406+
"name": "mixed_sizes_bimodal",
407+
"B": 1, "H": 16, "L": 1024, "D": 64,
408+
"block_sizes": [256, 256, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32] # Two large + many small; NOTE(review): 2x256 + 14x32 = 960, not L=1024 - confirm intended
409+
},
410+
411+
# ===== BATCH SIZE VARIATIONS =====
412+
{
413+
"name": "batch1_8x64_sparse87",
414+
"B": 1, "H": 16, "L": 512, "D": 64,
415+
"block_sizes": [64] * 8 # Single batch
416+
},
417+
{
418+
"name": "batch2_8x64_sparse87",
419+
"B": 2, "H": 16, "L": 512, "D": 64,
420+
"block_sizes": [64] * 8 # Small batch
421+
},
422+
{
423+
"name": "batch4_8x64_sparse87",
424+
"B": 4, "H": 16, "L": 512, "D": 64,
425+
"block_sizes": [64] * 8 # Medium batch
426+
},
427+
{
428+
"name": "batch8_8x64_sparse87",
429+
"B": 8, "H": 16, "L": 512, "D": 64,
430+
"block_sizes": [64] * 8 # Large batch
431+
},
432+
433+
# ===== HEAD COUNT VARIATIONS =====
434+
{
435+
"name": "heads4_8x64_sparse87",
436+
"B": 1, "H": 4, "L": 512, "D": 64,
437+
"block_sizes": [64] * 8 # Few heads
438+
},
439+
{
440+
"name": "heads16_8x64_sparse87",
441+
"B": 1, "H": 16, "L": 512, "D": 64,
442+
"block_sizes": [64] * 8 # Standard heads
443+
},
444+
{
445+
"name": "heads32_8x64_sparse87",
446+
"B": 1, "H": 32, "L": 512, "D": 64,
447+
"block_sizes": [64] * 8 # Many heads
448+
},
449+
{
450+
"name": "heads64_8x64_sparse87",
451+
"B": 1, "H": 64, "L": 512, "D": 64,
452+
"block_sizes": [64] * 8 # Very many heads
453+
},
454+
455+
# ===== TINY BLOCKS (EXTREME SPARSITY) =====
456+
{
457+
"name": "tiny_blocks_64x8_sparse98",
458+
"B": 1, "H": 16, "L": 512, "D": 64,
459+
"block_sizes": [8] * 64 # 98.4% sparse
460+
},
461+
{
462+
"name": "tiny_blocks_128x4_sparse99",
463+
"B": 1, "H": 16, "L": 512, "D": 64,
464+
"block_sizes": [4] * 128 # 99.2% sparse
465+
},
466+
467+
# ===== LARGE BLOCKS (DENSE PATTERNS) =====
468+
{
469+
"name": "large_blocks_2x256_sparse50",
470+
"B": 1, "H": 8, "L": 512, "D": 64,
471+
"block_sizes": [256, 256] # Only 50% sparse
335472
},
336473
{
337-
"name": "packed_16x64_sparse93",
338-
"B": 1, "H": 16, "L": 1024, "D": 128,
339-
"block_sizes": [64] * 16, # 93.75% sparse
340-
"expected_speedup": 3.0
474+
"name": "large_blocks_1x512_sparse0",
475+
"B": 1, "H": 8, "L": 512, "D": 64,
476+
"block_sizes": [512] # Not sparse at all
341477
},
478+
479+
# ===== REAL-WORLD SCENARIOS =====
342480
{
343-
"name": "bert_style_packing",
481+
"name": "bert_base_packing",
344482
"B": 2, "H": 12, "L": 512, "D": 64,
345-
"block_sizes": [128, 128, 128, 128], # BERT-style
346-
"expected_speedup": 1.3
483+
"block_sizes": [128, 128, 128, 128] # BERT-style sequence packing
484+
},
485+
{
486+
"name": "bert_large_packing",
487+
"B": 2, "H": 16, "L": 512, "D": 64,
488+
"block_sizes": [256, 256] # BERT-Large style
489+
},
490+
{
491+
"name": "gpt_style_packing",
492+
"B": 1, "H": 32, "L": 1024, "D": 64,
493+
"block_sizes": [512, 512] # GPT-style long sequences
494+
},
495+
{
496+
"name": "t5_encoder_packing",
497+
"B": 4, "H": 16, "L": 512, "D": 64,
498+
"block_sizes": [128, 128, 128, 128] # T5 encoder style
499+
},
500+
{
501+
"name": "longformer_sparse",
502+
"B": 1, "H": 16, "L": 2048, "D": 64,
503+
"block_sizes": [128] * 16 # Longformer-style local attention
504+
},
505+
506+
# ===== EDGE CASES =====
507+
{
508+
"name": "single_token_blocks",
509+
"B": 1, "H": 8, "L": 64, "D": 64,
510+
"block_sizes": [1] * 64 # Extreme case: every token is its own block
511+
},
512+
{
513+
"name": "uneven_tiny_blocks",
514+
"B": 1, "H": 16, "L": 512, "D": 64,
515+
"block_sizes": [16, 8, 32, 4, 64, 16, 8, 32, 4, 64] * 3 # Uneven small blocks; NOTE(review): sizes sum to 744, which exceeds L=512 - confirm intended
516+
},
517+
{
518+
"name": "power_of_2_progression",
519+
"B": 1, "H": 16, "L": 1024, "D": 64,
520+
"block_sizes": [512, 256, 128, 64, 32, 16, 8, 4, 2, 2] # Powers of 2
521+
},
522+
523+
# ===== PERFORMANCE STRESS TESTS =====
524+
{
525+
"name": "stress_very_long_seq",
526+
"B": 1, "H": 8, "L": 4096, "D": 64,
527+
"block_sizes": [256] * 16 # Very long sequences
528+
},
529+
{
530+
"name": "stress_many_heads",
531+
"B": 1, "H": 128, "L": 512, "D": 64,
532+
"block_sizes": [64] * 8 # Many attention heads
533+
},
534+
{
535+
"name": "stress_large_batch",
536+
"B": 16, "H": 16, "L": 512, "D": 64,
537+
"block_sizes": [64] * 8 # Large batch size
347538
},
348539
{
349-
"name": "large_seq_sparse",
350-
"B": 1, "H": 32, "L": 2048, "D": 64,
351-
"block_sizes": [256] * 8, # Large sequence, 87.5% sparse
352-
"expected_speedup": 2.5
540+
"name": "stress_wide_heads",
541+
"B": 1, "H": 16, "L": 512, "D": 256,
542+
"block_sizes": [64] * 8 # Very wide attention heads
353543
}
354544
]
355545

@@ -372,19 +562,18 @@ def run_block_diagonal_tests(evolved_fn):
372562

373563
# Calculate results
374564
speedup = time_spda / time_evolved if time_evolved > 0 else 0.0
375-
expected = config["expected_speedup"]
376565

377-
# Determine status
566+
# Determine status based on objective performance criteria
378567
if not correctness_ok:
379568
status = "❌ WRONG"
380569
color = "\033[91m" # Red
381-
elif speedup >= expected * 0.8: # Within 80% of expected
570+
elif speedup >= 1.5: # Significant speedup
382571
status = "✅ GOOD"
383572
color = "\033[92m" # Green
384-
elif speedup >= 1.1:
573+
elif speedup >= 1.1: # Modest speedup
385574
status = "⚡ OK"
386575
color = "\033[93m" # Yellow
387-
else:
576+
else: # No meaningful improvement
388577
status = "❌ SLOW"
389578
color = "\033[91m" # Red
390579
reset = "\033[0m"
@@ -399,7 +588,6 @@ def run_block_diagonal_tests(evolved_fn):
399588
block_results.append({
400589
"config": config["name"],
401590
"speedup": speedup,
402-
"expected": expected,
403591
"sparsity": sparsity,
404592
"status": status,
405593
"time_evolved": time_evolved,
@@ -455,7 +643,7 @@ def print_comprehensive_summary(official_results, block_results):
455643
print(f" Worst speedup: {min(block_speedups):.2f}x")
456644

457645
good_results = sum(1 for r in block_results if "✅" in r.get("status", ""))
458-
print(f" Tests meeting expectations: {good_results}/{len(block_results)} ({good_results/len(block_results)*100:.1f}%)")
646+
print(f" Tests with significant speedups: {good_results}/{len(block_results)} ({good_results/len(block_results)*100:.1f}%)")
459647

460648
# Overall assessment
461649
print(f"\n🎖️ OVERALL ASSESSMENT:")

0 commit comments

Comments
 (0)