 import pytest
 import torch
 
-if not torch.cuda.is_available():
-    pytest.skip("CUDA device not available", allow_module_level=True)
-
 from lm_saes.config import SAEConfig
 from lm_saes.sae import SparseAutoEncoder
 
@@ -21,6 +18,8 @@ def sae_config() -> SAEConfig:
         dtype=torch.float32,
         act_fn="topk",
         jump_relu_threshold=2.0,
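+        # top_k pairs with act_fn="topk": the number of features kept active per sample.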
+        top_k=2,
     )


@@ -42,7 +40,7 @@ def sae(sae_config: SAEConfig, generator: torch.Generator) -> SparseAutoEncoder:
         )
     if sae_config.use_decoder_bias:
         sae.decoder.bias.data = torch.randn(
-            sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
+            sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
         )
     if sae_config.use_glu_encoder:
         sae.encoder_glu.weight.data = torch.randn(
@@ -156,6 +154,21 @@ def test_compute_norm_factor(sae_config: SAEConfig, sae: SparseAutoEncoder):
     )


+def test_persistent_dataset_average_activation_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
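+    # Norms set via set_dataset_average_activation_norm should persist through a state-dict round trip.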
+    sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
+    assert sae.dataset_average_activation_norm == {"in": 3.0, "out": 2.0}
+    state_dict = sae._get_full_state_dict()
+    assert state_dict["dataset_average_activation_norm.in"] == 3.0
+    assert state_dict["dataset_average_activation_norm.out"] == 2.0
+
+    new_sae = SparseAutoEncoder(sae_config)
+    new_sae._load_full_state_dict(state_dict)
+    assert new_sae.cfg == sae.cfg
+    assert all(torch.allclose(p, q, atol=1e-4, rtol=1e-5) for p, q in zip(new_sae.parameters(), sae.parameters()))
+    assert new_sae.dataset_average_activation_norm == {"in": 3.0, "out": 2.0}
+
+
 def test_get_full_state_dict(sae_config: SAEConfig, sae: SparseAutoEncoder):
     sae_config.sparsity_include_decoder_norm = False
     state_dict = sae._get_full_state_dict()
@@ -205,3 +217,10 @@ def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
     assert torch.allclose(
         sae.decoder.bias.data, decoder_bias_data / math.sqrt(sae_config.d_model) * 2.0, atol=1e-4, rtol=1e-5
     )
+
+
+def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
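+    # The forward pass reconstructs the input, so the output shape should match the (1, 2) input.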
+    sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
+    output = sae.forward(torch.tensor([[1.0, 2.0]], device=sae_config.device, dtype=sae_config.dtype))
+    assert output.shape == (1, 2)