Skip to content

Commit 8a2803c

Browse files
committed
feat(language-model): introduce Matrix and Node classes for enhanced matrix operations
- Add `Matrix` class to manage a dynamic matrix with source and target node mappings. - Introduce `Node` dataclass to encapsulate node information and indices. - Implement methods for adding source and target nodes, updating the matrix, and retrieving submatrices. - Enhance type hints and docstrings for improved clarity and maintainability. - Add unit tests for `Matrix` class to validate functionality and ensure correctness.
1 parent ab88009 commit 8a2803c

File tree

2 files changed

+210
-1
lines changed

2 files changed

+210
-1
lines changed

src/lm_saes/backend/language_model.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
import json
45
import os
56
import re
@@ -8,7 +9,7 @@
89
from contextlib import contextmanager
910
from functools import partial
1011
from itertools import accumulate
11-
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
12+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Self, Union, cast
1213

1314
import einops
1415
import torch
@@ -348,6 +349,132 @@ def detach_hook_fn(x: torch.Tensor, hook: HookPoint):
348349
return detach_hooks
349350

350351

352+
@dataclass
353+
class Node:
354+
key: Any
355+
indices: torch.Tensor
356+
matrix_indices: torch.Tensor
357+
358+
359+
NodeInfo = tuple[Any, torch.Tensor]
360+
361+
362+
class Matrix:
363+
def __init__(self):
364+
self.matrix = torch.zeros(0, 0, dtype=torch.float32, device="cpu")
365+
self.source = {}
366+
self.target = {}
367+
368+
def _add_elements(self, node_infos: list[NodeInfo], dim: int):
369+
node_dict = self.source if dim == 0 else self.target
370+
371+
m_start = self.matrix.shape[dim]
372+
new_matrix_shape = list(self.matrix.shape)
373+
new_matrix_shape[dim] = sum([node_info[1].shape[0] for node_info in node_infos])
374+
self.matrix = torch.cat(
375+
[self.matrix, torch.zeros(new_matrix_shape, dtype=self.matrix.dtype, device=self.matrix.device)],
376+
dim=dim,
377+
)
378+
for node_info in node_infos:
379+
node_length = node_info[1].shape[0]
380+
node = Node(
381+
node_info[0],
382+
node_info[1],
383+
torch.arange(m_start, m_start + node_length, device=self.matrix.device),
384+
)
385+
self._update_node(node, node_dict)
386+
m_start += node_length
387+
388+
def add_source(self, node_infos: list[NodeInfo] | NodeInfo):
389+
node_infos = node_infos if isinstance(node_infos, list) else [node_infos]
390+
self._add_elements(node_infos, 0)
391+
392+
def add_target(self, node_infos: list[NodeInfo] | NodeInfo):
393+
node_infos = node_infos if isinstance(node_infos, list) else [node_infos]
394+
self._add_elements(node_infos, 1)
395+
396+
def update_matrix(self, matrix: torch.Tensor):
397+
self.matrix[:, :] = matrix
398+
399+
@staticmethod
400+
def _update_node(node: Node, node_dict: dict[Any, Node]):
401+
if node.key not in node_dict:
402+
node_dict[node.key] = node
403+
else:
404+
node_dict[node.key].matrix_indices = torch.cat(
405+
[node_dict[node.key].matrix_indices, node.matrix_indices],
406+
dim=0,
407+
)
408+
node_dict[node.key].indices = torch.cat(
409+
[node_dict[node.key].indices, node.indices],
410+
dim=0,
411+
)
412+
413+
def _get_sublines(self, node_infos: list[NodeInfo] | None, dim: int):
414+
full_node_dict = self.source if dim == 0 else self.target
415+
if node_infos is None:
416+
return (full_node_dict, torch.arange(self.matrix.shape[dim], device=self.matrix.device))
417+
418+
new_node_dict = {}
419+
old_matrix_indices = torch.zeros(0, device=self.matrix.device, dtype=torch.long)
420+
for node_info in node_infos:
421+
r = torch.empty(full_node_dict[node_info[0]].indices.max() + 1, device=self.matrix.device, dtype=torch.long)
422+
r[full_node_dict[node_info[0]].indices] = torch.arange(
423+
full_node_dict[node_info[0]].indices.shape[0], device=self.matrix.device
424+
)
425+
matrix_indices = full_node_dict[node_info[0]].matrix_indices[r[node_info[1]]]
426+
old_matrix_indices = torch.cat(
427+
[old_matrix_indices, matrix_indices],
428+
dim=0,
429+
)
430+
self._update_node(
431+
Node(
432+
node_info[0],
433+
node_info[1],
434+
torch.arange(
435+
old_matrix_indices.shape[0] - matrix_indices.shape[0],
436+
old_matrix_indices.shape[0],
437+
device=self.matrix.device,
438+
),
439+
),
440+
new_node_dict,
441+
)
442+
443+
return (new_node_dict, old_matrix_indices)
444+
445+
@classmethod
446+
def _build_submatrix(
447+
cls, matrix: torch.Tensor, source_node_dict: dict[Any, Node], target_node_dict: dict[Any, Node]
448+
) -> Self:
449+
submatrix = cls()
450+
submatrix.matrix = matrix
451+
submatrix.source = source_node_dict
452+
submatrix.target = target_node_dict
453+
return submatrix
454+
455+
def get_submatrix(
456+
self,
457+
source_node_infos: NodeInfo | list[NodeInfo] | None = None,
458+
target_node_infos: NodeInfo | list[NodeInfo] | None = None,
459+
):
460+
source_node_infos = (
461+
[source_node_infos]
462+
if not isinstance(source_node_infos, list) and source_node_infos is not None
463+
else source_node_infos
464+
)
465+
target_node_infos = (
466+
[target_node_infos]
467+
if not isinstance(target_node_infos, list) and target_node_infos is not None
468+
else target_node_infos
469+
)
470+
source_node_dict, source_matrix_indices = self._get_sublines(source_node_infos, 0)
471+
target_node_dict, target_matrix_indices = self._get_sublines(target_node_infos, 1)
472+
submatrix = Matrix._build_submatrix(
473+
self.matrix[source_matrix_indices][:, target_matrix_indices], source_node_dict, target_node_dict
474+
)
475+
return submatrix
476+
477+
351478
class AdjacencyMatrix(torch.Tensor):
352479
matrix: torch.Tensor
353480
source_list: list[tuple[torch.Tensor, Any]]

tests/unit/test_matrix.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import torch
2+
3+
from lm_saes.backend.language_model import Matrix
4+
5+
6+
def _build_sample_matrix() -> Matrix:
7+
matrix = Matrix()
8+
matrix.add_source([("s1", torch.tensor([0, 2])), ("s2", torch.tensor([1]))])
9+
matrix.add_target([("t1", torch.tensor([10, 11])), ("t2", torch.tensor([12]))])
10+
matrix.update_matrix(
11+
torch.tensor(
12+
[
13+
[1.0, 2.0, 3.0],
14+
[4.0, 5.0, 6.0],
15+
[7.0, 8.0, 9.0],
16+
],
17+
dtype=torch.float32,
18+
)
19+
)
20+
return matrix
21+
22+
23+
def test_add_source_and_target_shapes_and_node_mappings():
24+
matrix = Matrix()
25+
matrix.add_source(("src", torch.tensor([3, 7])))
26+
matrix.add_target(("tgt", torch.tensor([11, 13, 17])))
27+
28+
assert matrix.matrix.shape == (2, 3)
29+
assert torch.equal(matrix.source["src"].indices, torch.tensor([3, 7]))
30+
assert torch.equal(matrix.source["src"].matrix_indices, torch.tensor([0, 1]))
31+
assert torch.equal(matrix.target["tgt"].indices, torch.tensor([11, 13, 17]))
32+
assert torch.equal(matrix.target["tgt"].matrix_indices, torch.tensor([0, 1, 2]))
33+
34+
35+
def test_add_source_merges_same_key():
36+
matrix = Matrix()
37+
matrix.add_source(("src", torch.tensor([0, 2])))
38+
matrix.add_source(("src", torch.tensor([5])))
39+
40+
assert matrix.matrix.shape == (3, 0)
41+
assert torch.equal(matrix.source["src"].indices, torch.tensor([0, 2, 5]))
42+
assert torch.equal(matrix.source["src"].matrix_indices, torch.tensor([0, 1, 2]))
43+
44+
45+
def test_update_matrix_overwrites_values():
46+
matrix = Matrix()
47+
matrix.add_source(("src", torch.tensor([0, 1])))
48+
matrix.add_target(("tgt", torch.tensor([0, 1])))
49+
50+
new_values = torch.tensor([[1.5, -2.0], [3.25, 4.0]], dtype=torch.float32)
51+
matrix.update_matrix(new_values)
52+
53+
assert torch.equal(matrix.matrix, new_values)
54+
55+
56+
def test_get_submatrix_selects_expected_rows_and_columns():
57+
matrix = _build_sample_matrix()
58+
59+
submatrix = matrix.get_submatrix(
60+
source_node_infos=[("s1", torch.tensor([2]))],
61+
target_node_infos=[("t1", torch.tensor([11])), ("t2", torch.tensor([12]))],
62+
)
63+
64+
expected = torch.tensor([[5.0, 6.0]], dtype=torch.float32)
65+
assert torch.equal(submatrix.matrix, expected)
66+
assert torch.equal(submatrix.source["s1"].indices, torch.tensor([2]))
67+
assert torch.equal(submatrix.source["s1"].matrix_indices, torch.tensor([0]))
68+
assert torch.equal(submatrix.target["t1"].indices, torch.tensor([11]))
69+
assert torch.equal(submatrix.target["t1"].matrix_indices, torch.tensor([0]))
70+
assert torch.equal(submatrix.target["t2"].indices, torch.tensor([12]))
71+
assert torch.equal(submatrix.target["t2"].matrix_indices, torch.tensor([1]))
72+
73+
74+
def test_add_multiple_nodes_assigns_contiguous_matrix_indices():
75+
matrix = Matrix()
76+
matrix.add_source([("s1", torch.tensor([0, 2])), ("s2", torch.tensor([1]))])
77+
matrix.add_target([("t1", torch.tensor([10, 11])), ("t2", torch.tensor([12]))])
78+
79+
assert torch.equal(matrix.source["s1"].matrix_indices, torch.tensor([0, 1]))
80+
assert torch.equal(matrix.source["s2"].matrix_indices, torch.tensor([2]))
81+
assert torch.equal(matrix.target["t1"].matrix_indices, torch.tensor([0, 1]))
82+
assert torch.equal(matrix.target["t2"].matrix_indices, torch.tensor([2]))

0 commit comments

Comments
 (0)