Skip to content

Commit 24dc841

Browse files
authored
Merge pull request #72 from OpenMOSS/cc_upd
re-implement crosscoders
2 parents a5bc151 + 62c34c5 commit 24dc841

File tree

4 files changed

+245
-387
lines changed

4 files changed

+245
-387
lines changed

src/lm_saes/config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ class BaseSAEConfig(BaseModelConfig):
4848
So this class should not be used directly but only as a base config class for other SAE variants like SAEConfig, MixCoderConfig, CrossCoderConfig, etc.
4949
"""
5050

51+
sae_type: Literal["sae", "crosscoder", "mixcoder"]
5152
hook_point_in: str
5253
hook_point_out: str = Field(default_factory=lambda validated_model: validated_model["hook_point_in"])
5354
d_model: int
5455
expansion_factor: int
5556
use_decoder_bias: bool = True
5657
use_glu_encoder: bool = False
57-
act_fn: str = "relu"
58+
act_fn: Literal["relu", "jumprelu", "topk", "batchtopk"] = "relu"
5859
jump_relu_threshold: float = 0.0
5960
apply_decoder_bias_to_pre_encoder: bool = True
6061
norm_activation: str = "dataset-wise"
@@ -94,10 +95,15 @@ def save_hyperparameters(self, sae_path: Path | str, remove_loading_info: bool =
9495

9596

9697
class SAEConfig(BaseSAEConfig):
97-
pass
98+
sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'sae'
99+
98100

101+
class CrossCoderConfig(BaseSAEConfig):
102+
sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'crosscoder'
103+
99104

100105
class MixCoderConfig(BaseSAEConfig):
106+
sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'mixcoder'
101107
d_single_modal: int
102108
d_shared: int
103109
n_modalities: int = 2

0 commit comments

Comments
 (0)