
Commit 1fb85a1

Switch to argsort over TopK (#265)
Using argsort instead of TopK seems to use less memory and run faster: I measured `0.6873465579992626` for argsort and `0.7168450749959447` for TopK. This should be expected behavior for a flattened tensor.

Co-authored-by: Azazelle Guice <[email protected]>
1 parent 7467108 commit 1fb85a1

1 file changed: +2 -2 lines changed


mergekit/sparsify.py

Lines changed: 2 additions & 2 deletions
@@ -48,8 +48,8 @@ def magnitude(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens
     w = tensor.abs().view(-1)
     if w.device.type == "cpu":
         w = w.float()
-    topk = torch.topk(w, k=k, largest=True)
-    mask.view(-1)[topk.indices] = 1
+    topk = torch.argsort(w, descending=True)[:k]
+    mask.view(-1)[topk] = 1
 
     if rescale:
         res = rescale_sum(tensor, mask)
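
The change above swaps one way of picking the `k` largest magnitudes for another, so both should produce the same mask. The following is a minimal sketch of that equivalence, not code from mergekit: the helper names `magnitude_mask_topk` and `magnitude_mask_argsort` are hypothetical, and the input is built with distinct magnitudes on purpose, since the two calls may break ties at the cutoff differently.

```python
import torch


def magnitude_mask_topk(tensor: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: mask built the old way, with torch.topk."""
    mask = torch.zeros_like(tensor)
    w = tensor.abs().view(-1)
    topk = torch.topk(w, k=k, largest=True)
    mask.view(-1)[topk.indices] = 1
    return mask


def magnitude_mask_argsort(tensor: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: mask built the new way, with torch.argsort."""
    mask = torch.zeros_like(tensor)
    w = tensor.abs().view(-1)
    idx = torch.argsort(w, descending=True)[:k]
    mask.view(-1)[idx] = 1
    return mask


torch.manual_seed(0)
n = 2048
# Distinct magnitudes 1..n in random order, so there are no ties at the cutoff.
magnitudes = torch.randperm(n).float() + 1.0
# Random signs in {-1, +1}; the mask depends only on absolute values.
signs = torch.randint(0, 2, (n,)).float() * 2 - 1
t = (magnitudes * signs).view(64, 32)

k = int(0.25 * t.numel())  # keep the top 25% by magnitude
mask_topk = magnitude_mask_topk(t, k)
mask_argsort = magnitude_mask_argsort(t, k)
same = torch.equal(mask_topk, mask_argsort)
kept = int(mask_argsort.sum())
```

Here `same` confirms that both helpers select the same entries and `kept` equals `k`. The sketch checks correctness only; any speed or memory difference depends on tensor size, device, and PyTorch version and would need to be measured separately.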
