@@ -461,6 +461,91 @@ def compute_irrelevant(
461
461
}
462
462
463
463
464
@torch.no_grad()
def compute_basic_interpretation_axis_limits(
    model: HookedTransformer,
    *,
    include_uncentered: bool = False,
    include_equals_OV: bool = False,
    includes_eos: Optional[bool] = None,
    plot_with: Literal["plotly", "matplotlib"] = "plotly",
) -> Tuple[dict, dict[str, float]]:
    """Precompute the data for the basic interpretation plots and their shared color limits.

    Args:
        model: The model handed to every ``compute_*`` helper.
        include_uncentered: When true, the uncentered OV data is also computed,
            cached under ``("OV", False)`` and folded into the ``OV_*`` limits.
        include_equals_OV: Forwarded unchanged to ``compute_irrelevant``.
        includes_eos: Forwarded to every ``compute_*`` helper.  When ``None`` it
            defaults to ``model.cfg.d_vocab != model.cfg.d_vocab_out``.
        plot_with: Selects the title flavour passed to ``compute_irrelevant``:
            ``"html"`` for ``"plotly"``, ``"latex"`` otherwise.

    Returns:
        A ``(cached_data, axis_limits)`` pair.

        ``cached_data`` holds the raw helper results under these keys:

        * ``("QK", suffix)`` and ``("pos_QK", suffix)`` for ``suffix`` in
          ``""`` and ``"WithAttnScale"``;
        * ``("OV", True)`` (centered) and, only if ``include_uncentered``,
          ``("OV", False)``;
        * ``"irrelevant"``.

        ``axis_limits`` maps ``"{OV,QK,OVCentered,QKWithAttnScale}_{zmin,zmax}"``
        to the running minimum / maximum over the corresponding data.  A pair
        that is never updated keeps its initial ``+inf`` / ``-inf`` value.
    """
    cached_data = {}
    axis_limits = {
        "OV_zmin": np.inf,
        "OV_zmax": -np.inf,
        "QK_zmin": np.inf,
        "QK_zmax": -np.inf,
        "OVCentered_zmin": np.inf,
        "OVCentered_zmax": -np.inf,
        "QKWithAttnScale_zmin": np.inf,
        "QKWithAttnScale_zmax": -np.inf,
    }

    def widen_limits(prefix, data):
        # Grow the [zmin, zmax] interval stored under `prefix` so it covers `data`.
        zmin_key, zmax_key = f"{prefix}_zmin", f"{prefix}_zmax"
        axis_limits[zmin_key] = np.min([axis_limits[zmin_key], data.min()])
        axis_limits[zmax_key] = np.max([axis_limits[zmax_key], data.max()])

    if includes_eos is None:
        includes_eos = model.cfg.d_vocab != model.cfg.d_vocab_out
    title_kind = "html" if plot_with == "plotly" else "latex"
    attn_scale_variants = (("", False), ("WithAttnScale", True))

    for attn_scale, with_attn_scale in attn_scale_variants:
        QK = compute_QK(
            model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
        )
        widen_limits(f"QK{attn_scale}", QK["data"])
        cached_data[("QK", attn_scale)] = QK

    if include_uncentered:
        OV = compute_OV(model, centered=False, includes_eos=includes_eos)
        widen_limits("OV", OV["data"])
        cached_data[("OV", False)] = OV

    OV = compute_OV(model, centered=True, includes_eos=includes_eos)
    widen_limits("OVCentered", OV["data"])
    cached_data[("OV", True)] = OV

    for attn_scale, with_attn_scale in attn_scale_variants:
        pos_QK = compute_QK_by_position(
            model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
        )
        cached_data[("pos_QK", attn_scale)] = pos_QK
        # The positional QK data shares the QK limits of the same variant.  The
        # update is the same whether or not `includes_eos` is set, so no branch
        # is needed here.
        widen_limits(f"QK{attn_scale}", pos_QK["data"]["QK"])

    irrelevant = compute_irrelevant(
        model,
        include_equals_OV=include_equals_OV,
        includes_eos=includes_eos,
        title_kind=title_kind,
    )
    cached_data["irrelevant"] = irrelevant
    for data in irrelevant["data"].values():
        # Only 2-D entries contribute to the OV limits.
        if len(data.shape) == 2:
            widen_limits("OV", data)

    return cached_data, axis_limits
547
+
548
+
464
549
@torch .no_grad ()
465
550
def display_basic_interpretation (
466
551
model : HookedTransformer ,
@@ -485,34 +570,26 @@ def display_basic_interpretation(
485
570
plot_with : Literal ["plotly" , "matplotlib" ] = "plotly" ,
486
571
renderer : Optional [str ] = None ,
487
572
show : bool = True ,
573
+ cached_data : Optional [dict ] = None ,
574
+ axis_limits : Optional [dict [str , float ]] = None ,
488
575
) -> Tuple [dict [str , Union [go .Figure , matplotlib .figure .Figure ]], dict [str , float ]]:
576
+ if cached_data is None :
577
+ cached_data , axis_limits = compute_basic_interpretation_axis_limits (
578
+ model ,
579
+ include_uncentered = include_uncentered ,
580
+ include_equals_OV = include_equals_OV ,
581
+ includes_eos = includes_eos ,
582
+ plot_with = plot_with ,
583
+ )
489
584
QK_cmap = colorscale_to_cmap (QK_colorscale )
490
585
QK_SVD_cmap = colorscale_to_cmap (QK_SVD_colorscale )
491
586
OV_cmap = colorscale_to_cmap (OV_colorscale )
492
587
if includes_eos is None :
493
588
includes_eos = model .cfg .d_vocab != model .cfg .d_vocab_out
494
589
result = {}
495
- axis_limits = {
496
- "OV_zmin" : np .inf ,
497
- "OV_zmax" : - np .inf ,
498
- "QK_zmin" : np .inf ,
499
- "QK_zmax" : - np .inf ,
500
- "OVCentered_zmin" : np .inf ,
501
- "OVCentered_zmax" : - np .inf ,
502
- "QKWithAttnScale_zmin" : np .inf ,
503
- "QKWithAttnScale_zmax" : - np .inf ,
504
- }
590
+ title_kind = "html" if plot_with == "plotly" else "latex"
505
591
for attn_scale , with_attn_scale in (("" , False ), ("WithAttnScale" , True )):
506
- QK = compute_QK (
507
- model , includes_eos = includes_eos , with_attn_scale = with_attn_scale
508
- )
509
- axis_limits [f"QK{ attn_scale } _zmin" ] = np .min (
510
- [axis_limits [f"QK{ attn_scale } _zmin" ], QK ["data" ].min ()]
511
- )
512
- axis_limits [f"QK{ attn_scale } _zmax" ] = np .max (
513
- [axis_limits [f"QK{ attn_scale } _zmax" ], QK ["data" ].max ()]
514
- )
515
- title_kind = "html" if plot_with == "plotly" else "latex"
592
+ QK = cached_data [("QK" , attn_scale )]
516
593
if includes_eos :
517
594
match plot_with :
518
595
case "plotly" :
@@ -567,9 +644,7 @@ def display_basic_interpretation(
567
644
result [f"EQKE{ attn_scale } " ] = fig_qk
568
645
569
646
if include_uncentered :
570
- OV = compute_OV (model , centered = False , includes_eos = includes_eos )
571
- axis_limits ["OV_zmin" ] = np .min ([axis_limits ["OV_zmin" ], OV ["data" ].min ()])
572
- axis_limits ["OV_zmax" ] = np .max ([axis_limits ["OV_zmax" ], OV ["data" ].max ()])
647
+ OV = cached_data [("OV" , False )]
573
648
fig_ov = imshow (
574
649
OV ["data" ],
575
650
title = OV ["title" ][title_kind ],
@@ -585,13 +660,7 @@ def display_basic_interpretation(
585
660
show = show ,
586
661
)
587
662
result ["EVOU" ] = fig_ov
588
- OV = compute_OV (model , centered = True , includes_eos = includes_eos )
589
- axis_limits ["OVCentered_zmin" ] = np .min (
590
- [axis_limits ["OVCentered_zmin" ], OV ["data" ].min ()]
591
- )
592
- axis_limits ["OVCentered_zmax" ] = np .max (
593
- [axis_limits ["OVCentered_zmax" ], OV ["data" ].max ()]
594
- )
663
+ OV = cached_data [("OV" , True )]
595
664
fig_ov = imshow (
596
665
OV ["data" ],
597
666
title = OV ["title" ][title_kind ],
@@ -609,16 +678,8 @@ def display_basic_interpretation(
609
678
result ["EVOU-centered" ] = fig_ov
610
679
611
680
for attn_scale , with_attn_scale in (("" , False ), ("WithAttnScale" , True )):
612
- pos_QK = compute_QK_by_position (
613
- model , includes_eos = includes_eos , with_attn_scale = with_attn_scale
614
- )
681
+ pos_QK = cached_data [("pos_QK" , attn_scale )]
615
682
if includes_eos :
616
- axis_limits [f"QK{ attn_scale } _zmin" ] = np .min (
617
- [axis_limits [f"QK{ attn_scale } _zmin" ], pos_QK ["data" ]["QK" ].min ()]
618
- )
619
- axis_limits [f"QK{ attn_scale } _zmax" ] = np .max (
620
- [axis_limits [f"QK{ attn_scale } _zmax" ], pos_QK ["data" ]["QK" ].max ()]
621
- )
622
683
fig_qk = px .scatter (
623
684
pos_QK ["data" ],
624
685
title = pos_QK ["title" ][title_kind ],
@@ -631,12 +692,6 @@ def display_basic_interpretation(
631
692
if show :
632
693
fig_qk .show (renderer = renderer )
633
694
else :
634
- axis_limits [f"QK{ attn_scale } _zmin" ] = np .min (
635
- [axis_limits [f"QK{ attn_scale } _zmin" ], pos_QK ["data" ]["QK" ].min ()]
636
- )
637
- axis_limits [f"QK{ attn_scale } _zmax" ] = np .max (
638
- [axis_limits [f"QK{ attn_scale } _zmax" ], pos_QK ["data" ]["QK" ].max ()]
639
- )
640
695
fig_qk = imshow (
641
696
pos_QK ["data" ]["QK" ],
642
697
title = pos_QK ["title" ][title_kind ],
@@ -653,16 +708,9 @@ def display_basic_interpretation(
653
708
)
654
709
result [f"EQKP{ attn_scale } " ] = fig_qk
655
710
656
- irrelevant = compute_irrelevant (
657
- model ,
658
- include_equals_OV = include_equals_OV ,
659
- includes_eos = includes_eos ,
660
- title_kind = title_kind ,
661
- )
711
+ irrelevant = cached_data ["irrelevant" ]
662
712
for key , data in irrelevant ["data" ].items ():
663
713
if len (data .shape ) == 2 :
664
- axis_limits ["OV_zmin" ] = np .min ([axis_limits ["OV_zmin" ], data .min ()])
665
- axis_limits ["OV_zmax" ] = np .max ([axis_limits ["OV_zmax" ], data .max ()])
666
714
fig = imshow (
667
715
data ,
668
716
title = key ,
0 commit comments