
Commit 30e8f6c

Merge pull request #70 from OpenMOSS/cc_upd
fix misc sae.py issues (batch topk, activation func binary mask imple…
2 parents 9427334 + 8ae78c9 commit 30e8f6c

1 file changed: +20 -8 lines changed


src/lm_saes/sae.py

Lines changed: 20 additions & 8 deletions
@@ -92,20 +92,32 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
         return decoder_norm
 
     def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:  # type: ignore
-        assert cfg.act_fn.lower() in ["relu", "topk", "jumprelu"], f"Not implemented activation function {cfg.act_fn}"
+        assert cfg.act_fn.lower() in ["relu", "topk", "jumprelu", "batchtopk"], f"Not implemented activation function {cfg.act_fn}"
         if cfg.act_fn.lower() == "relu":
-            return lambda x: torch.where(x > 0, 1, 0)
-        if cfg.act_fn.lower() == "jumprelu":
-            return lambda x: torch.where(x > cfg.jump_relu_threshold, 1, 0)
-        if cfg.act_fn.lower() == "topk":
+            return lambda x: x.gt(0).float()
+        elif cfg.act_fn.lower() == "jumprelu":
+            return lambda x: x.gt(cfg.jump_relu_threshold).float()
+        elif cfg.act_fn.lower() == "topk":
 
             def topk_activation(x: torch.Tensor):
                 x = torch.clamp(x, min=0.0)
                 k = x.shape[-1] - self.current_k + 1
                 k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
-                return torch.where(x >= k_th_value, 1, 0)
+                return x.ge(k_th_value).float()
 
             return topk_activation
+
+        elif cfg.act_fn.lower() == "batchtopk":
+            def topk_activation(x: torch.Tensor):
+                assert x.dim() == 2
+                batch_size = x.size(0)
+
+                x = torch.clamp(x, min=0.0)
+                k = x.numel() - self.current_k * batch_size + 1
+                k_th_value, _ = torch.kthvalue(x.flatten(), k=k, dim=-1)
+                return x.ge(k_th_value).float()
+
+            return topk_activation
 
     def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:  # type: ignore
         """Compute the normalization factor for the activation vectors.
@@ -462,7 +474,7 @@ def compute_loss(
             "l_rec": l_rec,
         }
 
-        if not self.cfg.act_fn == "topk":
+        if "topk" not in self.cfg.act_fn:
            l_lp = torch.norm(feature_acts, p=lp, dim=-1)
            loss_dict["l_lp"] = l_lp
            assert self.current_l1_coefficient is not None
@@ -522,7 +534,7 @@ def log_statistics(self):
522534
}
523535
if self.cfg.use_decoder_bias:
524536
log_dict["metrics/decoder_bias_norm"] = self.decoder.bias.norm().item()
525-
if self.cfg.act_fn == "topk":
537+
if "topk" in self.cfg.act_fn:
526538
log_dict["sparsity/k"] = self.current_k
527539
else:
528540
log_dict["sparsity/l1_coefficient"] = self.current_l1_coefficient
