Commit 4faf9f6

fix(topk activation): add keepdim=True to enable broadcasting; make d… (#73)

* fix(topk activation): add keepdim=True to enable broadcasting; make dtype consistent without hardcode
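The core bug: torch.kthvalue(x, k, dim=-1) drops the reduced dimension, so for a batched (batch, d_sae) input the threshold comes back with shape (batch,), which does not broadcast against (batch, d_sae); it only happened to work for a single row. With keepdim=True the threshold is a (batch, 1) column that broadcasts row-wise. A minimal sketch with illustrative values (not from the repo):

import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                  [5.0, 6.0, 1.0, 2.0, 3.0, 4.0]])  # (batch=2, d_sae=6)
k = x.shape[-1] - 2 + 1  # threshold index for a top-2 mask

v, _ = torch.kthvalue(x, k=k, dim=-1)  # shape (2,)
# x.ge(v) would raise: shapes (2, 6) and (2,) do not broadcast

v, _ = torch.kthvalue(x, k=k, dim=-1, keepdim=True)  # shape (2, 1)
print(x.ge(v).to(x.dtype))  # mask in x's dtype instead of a hardcoded .float()
# tensor([[0., 0., 0., 0., 1., 1.],
#         [1., 1., 0., 0., 0., 0.]])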
1 parent 24dc841 commit 4faf9f6

File tree

src/lm_saes/config.py
src/lm_saes/crosscoder.py
src/lm_saes/sae.py
tests/unit/test_sae.py

4 files changed: +44 -28 lines changed

src/lm_saes/config.py

Lines changed: 5 additions & 5 deletions

@@ -95,15 +95,15 @@ def save_hyperparameters(self, sae_path: Path | str, remove_loading_info: bool =
 
 
 class SAEConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'sae'
-    
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "sae"
+
 
 class CrossCoderConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'crosscoder'
-    
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "crosscoder"
+
 
 class MixCoderConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'mixcoder'
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "mixcoder"
     d_single_modal: int
     d_shared: int
     n_modalities: int = 2

src/lm_saes/crosscoder.py

Lines changed: 5 additions & 12 deletions

@@ -20,24 +20,18 @@ class CrossCoder(SparseAutoEncoder):
     def __init__(self, cfg: BaseSAEConfig):
         super(CrossCoder, self).__init__(cfg)
 
-    def _decoder_norm(
-        self,
-        decoder: torch.nn.Linear,
-        keepdim: bool = False,
-        local_only=True,
-        aggregate="none"
-    ):
+    def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False, local_only=True, aggregate="none"):
         decoder_norm = super()._decoder_norm(
             decoder=decoder,
             keepdim=keepdim,
         )
         if not local_only:
             decoder_norm = all_reduce_tensor(
-                decoder_norm, 
+                decoder_norm,
                 aggregate=aggregate,
             )
         return decoder_norm
-    
+
     @overload
     def encode(
         self,
@@ -110,7 +104,7 @@ def encode(
 
         hidden_pre = all_reduce_tensor(hidden_pre, aggregate="sum")
         hidden_pre = self.hook_hidden_pre(hidden_pre)
-        
+
         if self.cfg.sparsity_include_decoder_norm:
             true_feature_acts = hidden_pre * self._decoder_norm(
                 decoder=self.decoder,
@@ -127,7 +121,7 @@ def encode(
         if return_hidden_pre:
             return feature_acts, hidden_pre
         return feature_acts
-    
+
     @overload
     def compute_loss(
         self,
@@ -229,4 +223,3 @@ def initialize_with_same_weight_across_layers(self):
         self.encoder.bias.data = get_tensor_from_specific_rank(self.encoder.bias.data.clone(), src=0)
         self.decoder.weight.data = get_tensor_from_specific_rank(self.decoder.weight.data.clone(), src=0)
         self.decoder.bias.data = get_tensor_from_specific_rank(self.decoder.bias.data.clone(), src=0)
-
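When local_only=False, the decoder norm is combined across ranks through the project's all_reduce_tensor helper. As a rough mental model only (this stand-in is an assumption, not the repo's actual helper), a sum-style all-reduce via torch.distributed looks like:

import torch
import torch.distributed as dist

def all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for all_reduce_tensor(t, aggregate="sum").
    # After the call, every rank holds the element-wise sum across ranks;
    # assumes a process group is already initialized.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t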

src/lm_saes/sae.py

Lines changed: 15 additions & 9 deletions

@@ -92,31 +92,37 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
         return decoder_norm
 
     def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:  # type: ignore
-        assert cfg.act_fn.lower() in ["relu", "topk", "jumprelu", "batchtopk"], f"Not implemented activation function {cfg.act_fn}"
+        assert cfg.act_fn.lower() in [
+            "relu",
+            "topk",
+            "jumprelu",
+            "batchtopk",
+        ], f"Not implemented activation function {cfg.act_fn}"
         if cfg.act_fn.lower() == "relu":
-            return lambda x: x.gt(0).float()
+            return lambda x: x.gt(0).to(x.dtype)
         elif cfg.act_fn.lower() == "jumprelu":
-            return lambda x: x.gt(cfg.jump_relu_threshold).float()
+            return lambda x: x.gt(cfg.jump_relu_threshold).to(x.dtype)
         elif cfg.act_fn.lower() == "topk":
 
             def topk_activation(x: torch.Tensor):
                 x = torch.clamp(x, min=0.0)
                 k = x.shape[-1] - self.current_k + 1
-                k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
-                return x.ge(k_th_value).float()
+                k_th_value, _ = torch.kthvalue(x, k=k, dim=-1, keepdim=True)
+                return x.ge(k_th_value).to(x.dtype)
 
             return topk_activation
-
+
         elif cfg.act_fn.lower() == "batchtopk":
+
             def topk_activation(x: torch.Tensor):
                 assert x.dim() == 2
                 batch_size = x.size(0)
-
+
                 x = torch.clamp(x, min=0.0)
                 k = x.numel() - self.current_k * batch_size + 1
                 k_th_value, _ = torch.kthvalue(x.flatten(), k=k, dim=-1)
-                return x.ge(k_th_value).float()
-
+                return x.ge(k_th_value).to(x.dtype)
+
             return topk_activation
 
     def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:  # type: ignore
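Why k = x.shape[-1] - self.current_k + 1: torch.kthvalue returns the k-th smallest entry, and the (d - current_k + 1)-th smallest is exactly the current_k-th largest, so everything >= it forms the top-current_k set. Returning x.ge(...).to(x.dtype) also keeps the mask in the model's working dtype (e.g. bfloat16) instead of upcasting through the old hardcoded .float(). A standalone sketch of the per-row path (helper name hypothetical):

import torch

def topk_mask(x: torch.Tensor, top_k: int) -> torch.Tensor:
    # Per-row binary mask of the top-`top_k` activations, kept in x's dtype.
    x = torch.clamp(x, min=0.0)
    k = x.shape[-1] - top_k + 1  # k-th smallest == top_k-th largest
    threshold, _ = torch.kthvalue(x, k=k, dim=-1, keepdim=True)  # shape (..., 1)
    return x.ge(threshold).to(x.dtype)

x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
print(topk_mask(x, top_k=2))  # tensor([[0., 0., 0., 0., 1., 1.]])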

tests/unit/test_sae.py

Lines changed: 19 additions & 2 deletions

@@ -88,11 +88,28 @@ def set_encoder_norm(norm: float):
 
 def test_sae_activate_fn(sae_config: SAEConfig, sae: SparseAutoEncoder):
     sae.current_k = 2
+    print(
+        sae.activation_function(
+            torch.tensor(
+                [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]],
+                device=sae_config.device,
+                dtype=sae_config.dtype,
+            )
+        )
+    )
     assert torch.allclose(
         sae.activation_function(
-            torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], device=sae_config.device, dtype=sae_config.dtype)
+            torch.tensor(
+                [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [5.0, 6.0, 1.0, 2.0, 3.0, 4.0]],
+                device=sae_config.device,
+                dtype=sae_config.dtype,
+            )
         ).to(sae_config.device, sae_config.dtype),
-        torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 1.0]], device=sae_config.device, dtype=sae_config.dtype),
+        torch.tensor(
+            [[0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0, 0.0, 0.0]],
+            device=sae_config.device,
+            dtype=sae_config.dtype,
+        ),
         atol=1e-4,
         rtol=1e-5,
     )
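The second row of the new test input is a permutation of the first, so each row needs its own threshold; without keepdim=True the batched call would fail to broadcast, which the old single-row test could not catch. The batchtopk branch, by contrast, thresholds across the flattened batch, so per-sample sparsity only holds on average. A hedged sketch of that difference (helper name hypothetical):

import torch

def batch_topk_mask(x: torch.Tensor, top_k: int) -> torch.Tensor:
    # Keep the top (top_k * batch_size) activations of the whole batch,
    # rather than top_k per row.
    assert x.dim() == 2
    x = torch.clamp(x, min=0.0)
    k = x.numel() - top_k * x.size(0) + 1  # k-th smallest of the flattened batch
    threshold, _ = torch.kthvalue(x.flatten(), k=k)  # 0-dim, broadcasts everywhere
    return x.ge(threshold).to(x.dtype)

x = torch.tensor([[9.0, 8.0, 1.0],
                  [2.0, 3.0, 4.0]])
print(batch_topk_mask(x, top_k=1))
# tensor([[1., 1., 0.],
#         [0., 0., 0.]])  <- the stronger row claims both slots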
