@@ -92,20 +92,32 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
         return decoder_norm
 
     def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:  # type: ignore
-        assert cfg.act_fn.lower() in ["relu", "topk", "jumprelu"], f"Not implemented activation function {cfg.act_fn}"
+        assert cfg.act_fn.lower() in ["relu", "topk", "jumprelu", "batchtopk"], f"Not implemented activation function {cfg.act_fn}"
         if cfg.act_fn.lower() == "relu":
-            return lambda x: torch.where(x > 0, 1, 0)
-        if cfg.act_fn.lower() == "jumprelu":
-            return lambda x: torch.where(x > cfg.jump_relu_threshold, 1, 0)
-        if cfg.act_fn.lower() == "topk":
+            return lambda x: x.gt(0).float()
+        elif cfg.act_fn.lower() == "jumprelu":
+            return lambda x: x.gt(cfg.jump_relu_threshold).float()
+        elif cfg.act_fn.lower() == "topk":
 
             def topk_activation(x: torch.Tensor):
                 x = torch.clamp(x, min=0.0)
                 k = x.shape[-1] - self.current_k + 1
                 k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
-                return torch.where(x >= k_th_value, 1, 0)
+                return x.ge(k_th_value).float()
 
             return topk_activation
+
+        elif cfg.act_fn.lower() == "batchtopk":
+            def topk_activation(x: torch.Tensor):
+                assert x.dim() == 2
+                batch_size = x.size(0)
+
+                x = torch.clamp(x, min=0.0)
+                k = x.numel() - self.current_k * batch_size + 1
+                k_th_value, _ = torch.kthvalue(x.flatten(), k=k, dim=-1)
+                return x.ge(k_th_value).float()
+
+            return topk_activation
 
     def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:  # type: ignore
         """Compute the normalization factor for the activation vectors.
@@ -462,7 +474,7 @@ def compute_loss(
462474 "l_rec" : l_rec ,
463475 }
464476
465- if not self .cfg .act_fn == "topk" :
477+ if "topk" not in self .cfg .act_fn :
466478 l_lp = torch .norm (feature_acts , p = lp , dim = - 1 )
467479 loss_dict ["l_lp" ] = l_lp
468480 assert self .current_l1_coefficient is not None
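A quick illustrative sketch (not from the commit) of what the relaxed substring check selects: both "topk" and "batchtopk" now skip the Lp sparsity penalty, while "relu" and "jumprelu" keep it. The same membership test reappears in the log_statistics hunk below.

for act_fn in ("relu", "jumprelu", "topk", "batchtopk"):
    adds_lp_penalty = "topk" not in act_fn   # compute_loss branch
    logs_k_instead_of_l1 = "topk" in act_fn  # log_statistics branch
    print(f"{act_fn}: lp penalty={adds_lp_penalty}, logs k={logs_k_instead_of_l1}")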
@@ -522,7 +534,7 @@ def log_statistics(self):
         }
         if self.cfg.use_decoder_bias:
             log_dict["metrics/decoder_bias_norm"] = self.decoder.bias.norm().item()
-        if self.cfg.act_fn == "topk":
+        if "topk" in self.cfg.act_fn:
             log_dict["sparsity/k"] = self.current_k
         else:
             log_dict["sparsity/l1_coefficient"] = self.current_l1_coefficient