@@ -313,43 +313,233 @@ def run_block_diagonal_tests(evolved_fn):
313313 print ("Format: Test | Shape | Blocks | Sparsity | Evolved | SPDA | Speedup | Status" )
314314 print ("-" * 80 )
315315
316- # Block-diagonal test configurations
316+ # Block-diagonal test configurations - comprehensive coverage
317317 block_configs = [
318+ # ===== BASIC SPARSITY PROGRESSION =====
318319 {
319- "name" : "packed_2x256_sparse50 " ,
320+ "name" : "dense_2x256_sparse50 " ,
320321 "B" : 1 , "H" : 8 , "L" : 512 , "D" : 64 ,
321- "block_sizes" : [256 , 256 ], # 50% sparse
322- "expected_speedup" : 1.2
322+ "block_sizes" : [256 , 256 ] # 50% sparse - baseline
323323 },
324324 {
325- "name" : "packed_4x128_sparse75 " ,
325+ "name" : "medium_4x128_sparse75 " ,
326326 "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
327- "block_sizes" : [128 , 128 , 128 , 128 ], # 75% sparse
328- "expected_speedup" : 1.5
327+ "block_sizes" : [128 , 128 , 128 , 128 ] # 75% sparse
329328 },
330329 {
331- "name" : "packed_8x128_sparse87" ,
330+ "name" : "sparse_8x64_sparse87" ,
331+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
332+ "block_sizes" : [64 ] * 8 # 87.5% sparse
333+ },
334+ {
335+ "name" : "very_sparse_16x32_sparse93" ,
336+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
337+ "block_sizes" : [32 ] * 16 # 93.75% sparse
338+ },
339+ {
340+ "name" : "extreme_sparse_32x16_sparse96" ,
341+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
342+ "block_sizes" : [16 ] * 32 # 96.875% sparse
343+ },
344+
345+ # ===== DIFFERENT SEQUENCE LENGTHS =====
346+ {
347+ "name" : "small_seq_4x32_sparse75" ,
348+ "B" : 1 , "H" : 8 , "L" : 128 , "D" : 64 ,
349+ "block_sizes" : [32 , 32 , 32 , 32 ] # Small sequences
350+ },
351+ {
352+ "name" : "medium_seq_8x64_sparse87" ,
353+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
354+ "block_sizes" : [64 ] * 8 # Medium sequences
355+ },
356+ {
357+ "name" : "large_seq_8x128_sparse87" ,
358+ "B" : 1 , "H" : 16 , "L" : 1024 , "D" : 64 ,
359+ "block_sizes" : [128 ] * 8 # Large sequences
360+ },
361+ {
362+ "name" : "huge_seq_16x128_sparse93" ,
363+ "B" : 1 , "H" : 32 , "L" : 2048 , "D" : 64 ,
364+ "block_sizes" : [128 ] * 16 # Very large sequences
365+ },
366+ {
367+ "name" : "giant_seq_32x64_sparse96" ,
368+ "B" : 1 , "H" : 32 , "L" : 2048 , "D" : 64 ,
369+ "block_sizes" : [64 ] * 32 # Extreme sequences
370+ },
371+
372+ # ===== DIFFERENT HEAD DIMENSIONS =====
373+ {
374+ "name" : "head64_8x64_sparse87" ,
375+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
376+ "block_sizes" : [64 ] * 8 # Standard head dim
377+ },
378+ {
379+ "name" : "head80_8x64_sparse87" ,
380+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 80 ,
381+ "block_sizes" : [64 ] * 8 # PaLM head dim
382+ },
383+ {
384+ "name" : "head128_8x64_sparse87" ,
385+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 128 ,
386+ "block_sizes" : [64 ] * 8 # Large head dim
387+ },
388+ {
389+ "name" : "head32_8x64_sparse87" ,
390+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 32 ,
391+ "block_sizes" : [64 ] * 8 # Small head dim
392+ },
393+
394+ # ===== MIXED BLOCK SIZES =====
395+ {
396+ "name" : "mixed_sizes_pyramid" ,
397+ "B" : 1 , "H" : 16 , "L" : 1024 , "D" : 64 ,
398+ "block_sizes" : [512 , 256 , 128 , 64 , 32 , 16 , 8 , 8 ] # Pyramid pattern
399+ },
400+ {
401+ "name" : "mixed_sizes_alternating" ,
332402 "B" : 1 , "H" : 16 , "L" : 1024 , "D" : 64 ,
333- "block_sizes" : [128 ] * 8 , # 87.5% sparse
334- "expected_speedup" : 2.0
403+ "block_sizes" : [128 , 64 , 128 , 64 , 128 , 64 , 128 , 64 , 128 , 64 ] # Alternating
404+ },
405+ {
406+ "name" : "mixed_sizes_bimodal" ,
407+ "B" : 1 , "H" : 16 , "L" : 1024 , "D" : 64 ,
408+ "block_sizes" : [256 , 256 , 32 , 32 , 32 , 32 , 32 , 32 , 32 , 32 , 32 , 32 , 32 , 32 , 32 , 32 ] # Two large + many small
409+ },
410+
411+ # ===== BATCH SIZE VARIATIONS =====
412+ {
413+ "name" : "batch1_8x64_sparse87" ,
414+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
415+ "block_sizes" : [64 ] * 8 # Single batch
416+ },
417+ {
418+ "name" : "batch2_8x64_sparse87" ,
419+ "B" : 2 , "H" : 16 , "L" : 512 , "D" : 64 ,
420+ "block_sizes" : [64 ] * 8 # Small batch
421+ },
422+ {
423+ "name" : "batch4_8x64_sparse87" ,
424+ "B" : 4 , "H" : 16 , "L" : 512 , "D" : 64 ,
425+ "block_sizes" : [64 ] * 8 # Medium batch
426+ },
427+ {
428+ "name" : "batch8_8x64_sparse87" ,
429+ "B" : 8 , "H" : 16 , "L" : 512 , "D" : 64 ,
430+ "block_sizes" : [64 ] * 8 # Large batch
431+ },
432+
433+ # ===== HEAD COUNT VARIATIONS =====
434+ {
435+ "name" : "heads4_8x64_sparse87" ,
436+ "B" : 1 , "H" : 4 , "L" : 512 , "D" : 64 ,
437+ "block_sizes" : [64 ] * 8 # Few heads
438+ },
439+ {
440+ "name" : "heads16_8x64_sparse87" ,
441+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
442+ "block_sizes" : [64 ] * 8 # Standard heads
443+ },
444+ {
445+ "name" : "heads32_8x64_sparse87" ,
446+ "B" : 1 , "H" : 32 , "L" : 512 , "D" : 64 ,
447+ "block_sizes" : [64 ] * 8 # Many heads
448+ },
449+ {
450+ "name" : "heads64_8x64_sparse87" ,
451+ "B" : 1 , "H" : 64 , "L" : 512 , "D" : 64 ,
452+ "block_sizes" : [64 ] * 8 # Very many heads
453+ },
454+
455+ # ===== TINY BLOCKS (EXTREME SPARSITY) =====
456+ {
457+ "name" : "tiny_blocks_64x8_sparse98" ,
458+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
459+ "block_sizes" : [8 ] * 64 # 98.4% sparse
460+ },
461+ {
462+ "name" : "tiny_blocks_128x4_sparse99" ,
463+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
464+ "block_sizes" : [4 ] * 128 # 99.2% sparse
465+ },
466+
467+ # ===== LARGE BLOCKS (DENSE PATTERNS) =====
468+ {
469+ "name" : "large_blocks_2x256_sparse50" ,
470+ "B" : 1 , "H" : 8 , "L" : 512 , "D" : 64 ,
471+ "block_sizes" : [256 , 256 ] # Only 50% sparse
335472 },
336473 {
337- "name" : "packed_16x64_sparse93" ,
338- "B" : 1 , "H" : 16 , "L" : 1024 , "D" : 128 ,
339- "block_sizes" : [64 ] * 16 , # 93.75% sparse
340- "expected_speedup" : 3.0
474+ "name" : "large_blocks_1x512_sparse0" ,
475+ "B" : 1 , "H" : 8 , "L" : 512 , "D" : 64 ,
476+ "block_sizes" : [512 ] # Not sparse at all
341477 },
478+
479+ # ===== REAL-WORLD SCENARIOS =====
342480 {
343- "name" : "bert_style_packing " ,
481+ "name" : "bert_base_packing " ,
344482 "B" : 2 , "H" : 12 , "L" : 512 , "D" : 64 ,
345- "block_sizes" : [128 , 128 , 128 , 128 ], # BERT-style
346- "expected_speedup" : 1.3
483+ "block_sizes" : [128 , 128 , 128 , 128 ] # BERT-style sequence packing
484+ },
485+ {
486+ "name" : "bert_large_packing" ,
487+ "B" : 2 , "H" : 16 , "L" : 512 , "D" : 64 ,
488+ "block_sizes" : [256 , 256 ] # BERT-Large style
489+ },
490+ {
491+ "name" : "gpt_style_packing" ,
492+ "B" : 1 , "H" : 32 , "L" : 1024 , "D" : 64 ,
493+ "block_sizes" : [512 , 512 ] # GPT-style long sequences
494+ },
495+ {
496+ "name" : "t5_encoder_packing" ,
497+ "B" : 4 , "H" : 16 , "L" : 512 , "D" : 64 ,
498+ "block_sizes" : [128 , 128 , 128 , 128 ] # T5 encoder style
499+ },
500+ {
501+ "name" : "longformer_sparse" ,
502+ "B" : 1 , "H" : 16 , "L" : 2048 , "D" : 64 ,
503+ "block_sizes" : [128 ] * 16 # Longformer-style local attention
504+ },
505+
506+ # ===== EDGE CASES =====
507+ {
508+ "name" : "single_token_blocks" ,
509+ "B" : 1 , "H" : 8 , "L" : 64 , "D" : 64 ,
510+ "block_sizes" : [1 ] * 64 # Extreme case: every token is its own block
511+ },
512+ {
513+ "name" : "uneven_tiny_blocks" ,
514+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 64 ,
515+ "block_sizes" : [16 , 8 , 32 , 4 , 64 , 16 , 8 , 32 , 4 , 64 ] * 3 # Uneven tiny blocks
516+ },
517+ {
518+ "name" : "power_of_2_progression" ,
519+ "B" : 1 , "H" : 16 , "L" : 1024 , "D" : 64 ,
520+ "block_sizes" : [512 , 256 , 128 , 64 , 32 , 16 , 8 , 4 , 2 , 2 ] # Powers of 2
521+ },
522+
523+ # ===== PERFORMANCE STRESS TESTS =====
524+ {
525+ "name" : "stress_very_long_seq" ,
526+ "B" : 1 , "H" : 8 , "L" : 4096 , "D" : 64 ,
527+ "block_sizes" : [256 ] * 16 # Very long sequences
528+ },
529+ {
530+ "name" : "stress_many_heads" ,
531+ "B" : 1 , "H" : 128 , "L" : 512 , "D" : 64 ,
532+ "block_sizes" : [64 ] * 8 # Many attention heads
533+ },
534+ {
535+ "name" : "stress_large_batch" ,
536+ "B" : 16 , "H" : 16 , "L" : 512 , "D" : 64 ,
537+ "block_sizes" : [64 ] * 8 # Large batch size
347538 },
348539 {
349- "name" : "large_seq_sparse" ,
350- "B" : 1 , "H" : 32 , "L" : 2048 , "D" : 64 ,
351- "block_sizes" : [256 ] * 8 , # Large sequence, 87.5% sparse
352- "expected_speedup" : 2.5
540+ "name" : "stress_wide_heads" ,
541+ "B" : 1 , "H" : 16 , "L" : 512 , "D" : 256 ,
542+ "block_sizes" : [64 ] * 8 # Very wide attention heads
353543 }
354544 ]
355545
@@ -372,19 +562,18 @@ def run_block_diagonal_tests(evolved_fn):
372562
373563 # Calculate results
374564 speedup = time_spda / time_evolved if time_evolved > 0 else 0.0
375- expected = config ["expected_speedup" ]
376565
377- # Determine status
566+ # Determine status based on objective performance criteria
378567 if not correctness_ok :
379568 status = "❌ WRONG"
380569 color = "\033 [91m" # Red
381- elif speedup >= expected * 0.8 : # Within 80% of expected
570+ elif speedup >= 1.5 : # Significant speedup
382571 status = "✅ GOOD"
383572 color = "\033 [92m" # Green
384- elif speedup >= 1.1 :
573+ elif speedup >= 1.1 : # Modest speedup
385574 status = "⚡ OK"
386575 color = "\033 [93m" # Yellow
387- else :
576+ else : # No meaningful improvement
388577 status = "❌ SLOW"
389578 color = "\033 [91m" # Red
390579 reset = "\033 [0m"
@@ -399,7 +588,6 @@ def run_block_diagonal_tests(evolved_fn):
399588 block_results .append ({
400589 "config" : config ["name" ],
401590 "speedup" : speedup ,
402- "expected" : expected ,
403591 "sparsity" : sparsity ,
404592 "status" : status ,
405593 "time_evolved" : time_evolved ,
@@ -455,7 +643,7 @@ def print_comprehensive_summary(official_results, block_results):
455643 print (f" Worst speedup: { min (block_speedups ):.2f} x" )
456644
457645 good_results = sum (1 for r in block_results if "✅" in r .get ("status" , "" ))
458- print (f" Tests meeting expectations : { good_results } /{ len (block_results )} ({ good_results / len (block_results )* 100 :.1f} %)" )
646+ print (f" Tests with significant speedups : { good_results } /{ len (block_results )} ({ good_results / len (block_results )* 100 :.1f} %)" )
459647
460648 # Overall assessment
461649 print (f"\n 🎖️ OVERALL ASSESSMENT:" )
0 commit comments