Skip to content

Commit 66507ba

Browse files
author
james
committed
bugfix
1 parent 80184c1 commit 66507ba

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

teal/grab_acts.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,9 @@
5555
print(len(text))
5656
bsz, seq_len = 10, 2048
5757

58-
input_ids = []
59-
for i in range(0, len(text), seq_len):
60-
ttext = text[i:i+seq_len]
61-
encodings = tokenizer(ttext, truncation=True, return_tensors="pt", max_length=seq_len, return_overflowing_tokens=True, padding="max_length")
62-
input_ids.append(encodings.input_ids)
58+
encodings = tokenizer(text, truncation=True, return_tensors="pt", max_length=seq_len, return_overflowing_tokens=True, padding="max_length")
6359

64-
input_ids = torch.cat(input_ids, dim=0)[:bsz,:].to(device="cuda:0")
60+
input_ids = encodings.input_ids[:bsz,:].to(device="cuda:0")
6561
print(input_ids.shape)
6662

6763
hidden_states = model.model.embed_tokens(input_ids)

utils/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def forward(self, x):
4949
return self.apply(x)
5050

5151
def apply(self, x):
52+
nonzero = (x.abs() > self.threshold).sum()
53+
print(f"Nonzero proportion: {nonzero / x.numel()}")
5254
return x.abs().gt(self.threshold) * x
5355

5456
def get_threshold(self):
@@ -222,7 +224,7 @@ def get_sparse_model(model_name, device, histogram_path, **kwargs):
222224

223225
def get_tokenizer(tokenizer_name):
224226
tokenizer = transformers.AutoTokenizer.from_pretrained(
225-
tokenizer_name, use_fast=False, trust_remote_code=True
227+
tokenizer_name, use_fast=True, trust_remote_code=True
226228
)
227229

228230
if tokenizer.pad_token_id is None:

0 commit comments

Comments (0)