|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from dataclasses import dataclass |
3 | 4 | import json |
4 | 5 | import os |
5 | 6 | import re |
|
8 | 9 | from contextlib import contextmanager |
9 | 10 | from functools import partial |
10 | 11 | from itertools import accumulate |
11 | | -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast |
| 12 | +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Self, Union, cast |
12 | 13 |
|
13 | 14 | import einops |
14 | 15 | import torch |
@@ -348,6 +349,132 @@ def detach_hook_fn(x: torch.Tensor, hook: HookPoint): |
348 | 349 | return detach_hooks |
349 | 350 |
|
350 | 351 |
|
| 352 | +@dataclass |
| 353 | +class Node: |
| 354 | + key: Any |
| 355 | + indices: torch.Tensor |
| 356 | + matrix_indices: torch.Tensor |
| 357 | + |
| 358 | + |
| 359 | +NodeInfo = tuple[Any, torch.Tensor] |
| 360 | + |
| 361 | + |
| 362 | +class Matrix: |
| 363 | + def __init__(self): |
| 364 | + self.matrix = torch.zeros(0, 0, dtype=torch.float32, device="cpu") |
| 365 | + self.source = {} |
| 366 | + self.target = {} |
| 367 | + |
| 368 | + def _add_elements(self, node_infos: list[NodeInfo], dim: int): |
| 369 | + node_dict = self.source if dim == 0 else self.target |
| 370 | + |
| 371 | + m_start = self.matrix.shape[dim] |
| 372 | + new_matrix_shape = list(self.matrix.shape) |
| 373 | + new_matrix_shape[dim] = sum([node_info[1].shape[0] for node_info in node_infos]) |
| 374 | + self.matrix = torch.cat( |
| 375 | + [self.matrix, torch.zeros(new_matrix_shape, dtype=self.matrix.dtype, device=self.matrix.device)], |
| 376 | + dim=dim, |
| 377 | + ) |
| 378 | + for node_info in node_infos: |
| 379 | + node_length = node_info[1].shape[0] |
| 380 | + node = Node( |
| 381 | + node_info[0], |
| 382 | + node_info[1], |
| 383 | + torch.arange(m_start, m_start + node_length, device=self.matrix.device), |
| 384 | + ) |
| 385 | + self._update_node(node, node_dict) |
| 386 | + m_start += node_length |
| 387 | + |
| 388 | + def add_source(self, node_infos: list[NodeInfo] | NodeInfo): |
| 389 | + node_infos = node_infos if isinstance(node_infos, list) else [node_infos] |
| 390 | + self._add_elements(node_infos, 0) |
| 391 | + |
| 392 | + def add_target(self, node_infos: list[NodeInfo] | NodeInfo): |
| 393 | + node_infos = node_infos if isinstance(node_infos, list) else [node_infos] |
| 394 | + self._add_elements(node_infos, 1) |
| 395 | + |
| 396 | + def update_matrix(self, matrix: torch.Tensor): |
| 397 | + self.matrix[:, :] = matrix |
| 398 | + |
| 399 | + @staticmethod |
| 400 | + def _update_node(node: Node, node_dict: dict[Any, Node]): |
| 401 | + if node.key not in node_dict: |
| 402 | + node_dict[node.key] = node |
| 403 | + else: |
| 404 | + node_dict[node.key].matrix_indices = torch.cat( |
| 405 | + [node_dict[node.key].matrix_indices, node.matrix_indices], |
| 406 | + dim=0, |
| 407 | + ) |
| 408 | + node_dict[node.key].indices = torch.cat( |
| 409 | + [node_dict[node.key].indices, node.indices], |
| 410 | + dim=0, |
| 411 | + ) |
| 412 | + |
| 413 | + def _get_sublines(self, node_infos: list[NodeInfo] | None, dim: int): |
| 414 | + full_node_dict = self.source if dim == 0 else self.target |
| 415 | + if node_infos is None: |
| 416 | + return (full_node_dict, torch.arange(self.matrix.shape[dim], device=self.matrix.device)) |
| 417 | + |
| 418 | + new_node_dict = {} |
| 419 | + old_matrix_indices = torch.zeros(0, device=self.matrix.device, dtype=torch.long) |
| 420 | + for node_info in node_infos: |
| 421 | + r = torch.empty(full_node_dict[node_info[0]].indices.max() + 1, device=self.matrix.device, dtype=torch.long) |
| 422 | + r[full_node_dict[node_info[0]].indices] = torch.arange( |
| 423 | + full_node_dict[node_info[0]].indices.shape[0], device=self.matrix.device |
| 424 | + ) |
| 425 | + matrix_indices = full_node_dict[node_info[0]].matrix_indices[r[node_info[1]]] |
| 426 | + old_matrix_indices = torch.cat( |
| 427 | + [old_matrix_indices, matrix_indices], |
| 428 | + dim=0, |
| 429 | + ) |
| 430 | + self._update_node( |
| 431 | + Node( |
| 432 | + node_info[0], |
| 433 | + node_info[1], |
| 434 | + torch.arange( |
| 435 | + old_matrix_indices.shape[0] - matrix_indices.shape[0], |
| 436 | + old_matrix_indices.shape[0], |
| 437 | + device=self.matrix.device, |
| 438 | + ), |
| 439 | + ), |
| 440 | + new_node_dict, |
| 441 | + ) |
| 442 | + |
| 443 | + return (new_node_dict, old_matrix_indices) |
| 444 | + |
| 445 | + @classmethod |
| 446 | + def _build_submatrix( |
| 447 | + cls, matrix: torch.Tensor, source_node_dict: dict[Any, Node], target_node_dict: dict[Any, Node] |
| 448 | + ) -> Self: |
| 449 | + submatrix = cls() |
| 450 | + submatrix.matrix = matrix |
| 451 | + submatrix.source = source_node_dict |
| 452 | + submatrix.target = target_node_dict |
| 453 | + return submatrix |
| 454 | + |
| 455 | + def get_submatrix( |
| 456 | + self, |
| 457 | + source_node_infos: NodeInfo | list[NodeInfo] | None = None, |
| 458 | + target_node_infos: NodeInfo | list[NodeInfo] | None = None, |
| 459 | + ): |
| 460 | + source_node_infos = ( |
| 461 | + [source_node_infos] |
| 462 | + if not isinstance(source_node_infos, list) and source_node_infos is not None |
| 463 | + else source_node_infos |
| 464 | + ) |
| 465 | + target_node_infos = ( |
| 466 | + [target_node_infos] |
| 467 | + if not isinstance(target_node_infos, list) and target_node_infos is not None |
| 468 | + else target_node_infos |
| 469 | + ) |
| 470 | + source_node_dict, source_matrix_indices = self._get_sublines(source_node_infos, 0) |
| 471 | + target_node_dict, target_matrix_indices = self._get_sublines(target_node_infos, 1) |
| 472 | + submatrix = Matrix._build_submatrix( |
| 473 | + self.matrix[source_matrix_indices][:, target_matrix_indices], source_node_dict, target_node_dict |
| 474 | + ) |
| 475 | + return submatrix |
| 476 | + |
| 477 | + |
351 | 478 | class AdjacencyMatrix(torch.Tensor): |
352 | 479 | matrix: torch.Tensor |
353 | 480 | source_list: list[tuple[torch.Tensor, Any]] |
|
0 commit comments