
Commit a5bc151

Merge pull request #71 from OpenMOSS/sae/persistant_dataset_norm
feat(sae): support saving/loading dataset_average_activation_norm to/from SAE state dict
2 parents 30e8f6c + 56d5251 commit a5bc151

File tree: 2 files changed, +43 −7 lines

src/lm_saes/sae.py (21 additions, 3 deletions)
@@ -196,6 +196,11 @@ def _get_full_state_dict(self):  # should be overridden by subclasses
         if self.device_mesh and self.device_mesh["model"].size(0) > 1:
             state_dict = {k: v.full_tensor() if isinstance(v, DTensor) else v for k, v in state_dict.items()}
 
+        # Add dataset_average_activation_norm to state dict
+        if self.dataset_average_activation_norm is not None:
+            for hook_point, value in self.dataset_average_activation_norm.items():
+                state_dict[f"dataset_average_activation_norm.{hook_point}"] = torch.tensor(value)
+
         # If sparsity_include_decoder_norm is False, we need to normalize the decoder weight before saving
         # We use a deepcopy to avoid modifying the original weight to avoid affecting the training progress
         if not self.cfg.sparsity_include_decoder_norm:
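The save path flattens the per-hook-point norm dict into 0-d tensors keyed with a "dataset_average_activation_norm." prefix, so the values travel inside the regular state dict (state-dict entries have to be tensors for downstream serialization). A minimal standalone sketch of that flattening step, using illustrative values rather than anything from the repo:

import torch

# Illustrative per-hook-point norms, as set via set_dataset_average_activation_norm
dataset_average_activation_norm = {"in": 3.0, "out": 2.0}

state_dict: dict[str, torch.Tensor] = {}  # would normally also hold the SAE parameters
for hook_point, value in dataset_average_activation_norm.items():
    # Store each float as a 0-d tensor under a prefixed key
    state_dict[f"dataset_average_activation_norm.{hook_point}"] = torch.tensor(value)

assert torch.isclose(state_dict["dataset_average_activation_norm.in"], torch.tensor(3.0))
assert torch.isclose(state_dict["dataset_average_activation_norm.out"], torch.tensor(2.0))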
@@ -489,6 +494,15 @@ def compute_loss(
             return loss, (loss_dict, aux_data)
         return loss
 
+    def _load_full_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
+        # Extract and set dataset_average_activation_norm if present
+        norm_keys = [k for k in state_dict.keys() if k.startswith("dataset_average_activation_norm.")]
+        if norm_keys:
+            dataset_norm = {key.split(".", 1)[1]: state_dict[key].item() for key in norm_keys}
+            self.set_dataset_average_activation_norm(dataset_norm)
+            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("dataset_average_activation_norm.")}
+        self.load_state_dict(state_dict, strict=self.cfg.strict_loading)
+
     @classmethod
     def from_config(cls, cfg: SAEConfig) -> "SparseAutoEncoder":
         if cfg.sae_pretrained_name_or_path is None:
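On the load side, the prefixed entries are peeled off the state dict, converted back to Python floats with .item(), and handed to set_dataset_average_activation_norm before the remaining tensors go through the normal load_state_dict. A standalone sketch of that filter-and-restore step, again with illustrative names and values:

import torch

prefix = "dataset_average_activation_norm."
state_dict = {
    "encoder.weight": torch.zeros(4, 2),  # stand-in for a real SAE parameter
    f"{prefix}in": torch.tensor(3.0),
    f"{prefix}out": torch.tensor(2.0),
}

# Recover the per-hook-point norms as plain floats
norm_keys = [k for k in state_dict if k.startswith(prefix)]
dataset_norm = {k.split(".", 1)[1]: state_dict[k].item() for k in norm_keys}

# Strip the extra entries so the remaining dict matches the module's parameters
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}

assert dataset_norm == {"in": 3.0, "out": 2.0}
assert list(state_dict) == ["encoder.weight"]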
@@ -512,12 +526,16 @@ def from_config(cls, cfg: SAEConfig) -> "SparseAutoEncoder":
             raise FileNotFoundError(f"Pretrained model not found at {cfg.sae_pretrained_name_or_path}")
 
         if ckpt_path.endswith(".safetensors"):
-            state_dict = safe.load_file(ckpt_path, device=cfg.device)
+            state_dict: dict[str, torch.Tensor] = safe.load_file(ckpt_path, device=cfg.device)
         else:
-            state_dict = torch.load(ckpt_path, map_location=cfg.device)["sae"]
+            state_dict: dict[str, torch.Tensor] = torch.load(
+                ckpt_path,
+                map_location=cfg.device,
+                weights_only=True,
+            )["sae"]
 
         model = cls(cfg)
-        model.load_state_dict(state_dict, strict=cfg.strict_loading)
+        model._load_full_state_dict(state_dict)
         return model
 
     @classmethod
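Besides routing loading through _load_full_state_dict, this hunk adds weights_only=True to torch.load, which restricts unpickling to plain tensors and containers instead of arbitrary Python objects. A short sketch of the two checkpoint paths in isolation, assuming the diff's safe alias refers to safetensors.torch and with an illustrative file name:

import torch
import safetensors.torch as safe  # assumed alias for the `safe` used in the diff

device = "cpu"
ckpt_path = "final.safetensors"  # illustrative path, not from the repo

if ckpt_path.endswith(".safetensors"):
    # safetensors stores a flat {name: tensor} mapping; no pickle involved
    state_dict = safe.load_file(ckpt_path, device=device)
else:
    # weights_only=True makes torch.load refuse arbitrary pickled objects,
    # so the checkpoint must hold only tensor data under the "sae" key
    state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)["sae"]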

tests/unit/test_sae.py (22 additions, 4 deletions)
@@ -3,9 +3,6 @@
 import pytest
 import torch
 
-if not torch.cuda.is_available():
-    pytest.skip("CUDA device not available", allow_module_level=True)
-
 from lm_saes.config import SAEConfig
 from lm_saes.sae import SparseAutoEncoder

@@ -21,6 +18,7 @@ def sae_config() -> SAEConfig:
         dtype=torch.float32,
         act_fn="topk",
         jump_relu_threshold=2.0,
+        top_k=2,
     )
 
 
@@ -42,7 +40,7 @@ def sae(sae_config: SAEConfig, generator: torch.Generator) -> SparseAutoEncoder:
     )
     if sae_config.use_decoder_bias:
         sae.decoder.bias.data = torch.randn(
-            sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
+            sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
         )
     if sae_config.use_glu_encoder:
         sae.encoder_glu.weight.data = torch.randn(
@@ -156,6 +154,20 @@ def test_compute_norm_factor(sae_config: SAEConfig, sae: SparseAutoEncoder):
     )
 
 
+def test_persistent_dataset_average_activation_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
+    sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
+    assert sae.dataset_average_activation_norm == {"in": 3.0, "out": 2.0}
+    state_dict = sae._get_full_state_dict()
+    assert state_dict["dataset_average_activation_norm.in"] == 3.0
+    assert state_dict["dataset_average_activation_norm.out"] == 2.0
+
+    new_sae = SparseAutoEncoder(sae_config)
+    new_sae._load_full_state_dict(state_dict)
+    assert new_sae.cfg == sae.cfg
+    assert all(torch.allclose(p, q, atol=1e-4, rtol=1e-5) for p, q in zip(new_sae.parameters(), sae.parameters()))
+    assert new_sae.dataset_average_activation_norm == {"in": 3.0, "out": 2.0}
+
+
 def test_get_full_state_dict(sae_config: SAEConfig, sae: SparseAutoEncoder):
     sae_config.sparsity_include_decoder_norm = False
     state_dict = sae._get_full_state_dict()
@@ -205,3 +217,9 @@ def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
     assert torch.allclose(
         sae.decoder.bias.data, decoder_bias_data / math.sqrt(sae_config.d_model) * 2.0, atol=1e-4, rtol=1e-5
     )
+
+
+def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
+    sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
+    output = sae.forward(torch.tensor([[1.0, 2.0]], device=sae_config.device, dtype=sae_config.dtype))
+    assert output.shape == (1, 2)
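With the module-level CUDA skip removed, this test file runs on CPU. One way to invoke only the tests added in this commit programmatically (a sketch; assumes pytest is installed and the call is made from the repository root):

import pytest

# Select the two new tests by name via pytest's -k expression
pytest.main(["tests/unit/test_sae.py", "-k", "persistent_dataset_average_activation_norm or test_forward", "-q"])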
