Skip to content

Commit f9a238c

Browse files
committed
refactor(circuit): enhance node mapping and tracing result handling
Updated the map_idx_to_nodes method to take a target_layer argument, which is now used as the layer of the bias nodes it creates. Modified the extract_QK_tracing_result function to accept a layer parameter and pass it through to map_idx_to_nodes. Additionally, added a stringify_nodes method to the QKTracingResults class that replaces Node objects with their node_id strings. Removed the debug print statement in the create_circuit function to clean up the code.
1 parent 66c288d commit f9a238c

File tree

4 files changed

+46
-13
lines changed

4 files changed

+46
-13
lines changed

server/routers/circuits.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,6 @@ def concretize_graph_data(graph_data: dict[str, Any]):
139139
def create_circuit(sae_set_name: str, request: GenerateCircuitRequest):
140140
"""Generate and save a circuit graph for a given prompt and SAE set."""
141141

142-
print(request.model_dump())
143-
144142
sae_set = client.get_sae_set(name=sae_set_name)
145143
assert sae_set is not None, f"SAE set {sae_set_name} not found"
146144
sae_names = sae_set.sae_names

src/lm_saes/circuit/utils/attn_scores_attribution.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def drop_bias_terms(self):
6969
), "We might not want to drop bias terms for the second time."
7070
return self
7171

72-
def map_idx_to_nodes(self, indices: torch.Tensor, input_ids: torch.Tensor) -> list[Node]:
72+
def map_idx_to_nodes(self, indices: torch.Tensor, target_layer: int, input_ids: torch.Tensor) -> list[Node]:
7373
"""Map component indices to node names.
7474
7575
Args:
@@ -147,7 +147,7 @@ def map_idx_to_nodes(self, indices: torch.Tensor, input_ids: torch.Tensor) -> li
147147
)
148148
else:
149149
res = Node.bias_node(
150-
int(layer),
150+
target_layer.item(),
151151
self.pos,
152152
bias_name=self.bias_names[idx - decoder_bias_idx],
153153
is_from_qk_tracing=True,
@@ -378,6 +378,7 @@ def get_single_side_QK_components(
378378
def extract_QK_tracing_result(
379379
q_side: ResStreamComponents,
380380
k_side: ResStreamComponents,
381+
layer: int,
381382
input_ids: torch.Tensor,
382383
topk: int = 10,
383384
) -> QKTracingResults:
@@ -402,8 +403,8 @@ def extract_QK_tracing_result(
402403
# pair-wise top contributors
403404
pair_wise_contributors = list(
404405
zip(
405-
q_side.map_idx_to_nodes(q_features, input_ids),
406-
k_side.map_idx_to_nodes(k_features, input_ids),
406+
q_side.map_idx_to_nodes(q_features, layer, input_ids),
407+
k_side.map_idx_to_nodes(k_features, layer, input_ids),
407408
topk_pairwise_attr_entries.values.cpu().tolist(),
408409
)
409410
)
@@ -413,13 +414,13 @@ def extract_QK_tracing_result(
413414
)
414415
top_q_marginal_contributors = list(
415416
zip(
416-
q_side.map_idx_to_nodes(top_q_marginal_contributors.indices, input_ids),
417+
q_side.map_idx_to_nodes(top_q_marginal_contributors.indices, layer, input_ids),
417418
top_q_marginal_contributors.values.cpu().tolist(),
418419
)
419420
)
420421
top_k_marginal_contributors = list(
421422
zip(
422-
k_side.map_idx_to_nodes(top_k_marginal_contributors.indices, input_ids),
423+
k_side.map_idx_to_nodes(top_k_marginal_contributors.indices, layer, input_ids),
423424
top_k_marginal_contributors.values.cpu().tolist(),
424425
)
425426
)
@@ -501,6 +502,6 @@ def compute_attn_scores_attribution(
501502
q_side=False,
502503
).drop_bias_terms()
503504

504-
result = extract_QK_tracing_result(q_side, k_side, input_ids, topk=topk)
505+
result = extract_QK_tracing_result(q_side, k_side, layer, input_ids, topk=topk)
505506
score_attribution_cache[cache_key] = result
506507
return result

src/lm_saes/circuit/utils/create_graph_files.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,18 @@ def append_qk_tracing_results(graph: Graph, used_nodes: List[Node], clt_names, l
163163
for node in used_nodes:
164164
if node.qk_tracing_results is not None:
165165
from_qk_tracing_nodes.update(node.qk_tracing_results.get_nodes())
166+
166167
nodes_to_add = from_qk_tracing_nodes - existing_nodes
167168
for node in nodes_to_add:
168169
if node.feature_type == "lorsa":
169170
node.sae_name = lorsa_names[node.layer // 2]
170171
elif node.feature_type == "cross layer transcoder":
171172
node.sae_name = clt_names[node.layer // 2]
172173
used_nodes.append(node)
174+
175+
for node in used_nodes:
176+
if node.qk_tracing_results is not None:
177+
node.qk_tracing_results.stringify_nodes()
173178
return used_nodes
174179

175180

src/lm_saes/circuit/utils/graph_file_utils.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def bias_node(cls, layer, pos, bias_name, influence=None, is_from_qk_tracing=Fal
9494
node_id=f"{layer}_{bias_name}_{pos}",
9595
layer=layer,
9696
ctx_idx=pos,
97-
feature_type=bias_name,
97+
feature_type="bias",
9898
influence=influence,
9999
is_from_qk_tracing=is_from_qk_tracing,
100100
)
@@ -154,9 +154,11 @@ class QKTracingResults:
154154
top_k_marginal_contributors: List of tuples containing (k_node, score) for top K marginal contributors.
155155
"""
156156

157-
pair_wise_contributors: list[tuple[Node, Node, float]]
158-
top_q_marginal_contributors: list[tuple[Node, float]]
159-
top_k_marginal_contributors: list[tuple[Node, float]]
157+
NodeType = Node | str
158+
159+
pair_wise_contributors: list[tuple[NodeType, NodeType, float]]
160+
top_q_marginal_contributors: list[tuple[NodeType, float]]
161+
top_k_marginal_contributors: list[tuple[NodeType, float]]
160162

161163
def get_nodes(self) -> set[Node]:
162164
all_relevant_nodes = set()
@@ -168,3 +170,30 @@ def get_nodes(self) -> set[Node]:
168170
for k_node, _ in self.top_k_marginal_contributors:
169171
all_relevant_nodes.add(k_node)
170172
return all_relevant_nodes
173+
174+
def stringify_nodes(self):
175+
assert all(
176+
isinstance(q_node, Node) and isinstance(k_node, Node) for q_node, k_node, _ in self.pair_wise_contributors
177+
) or all(
178+
isinstance(q_node, str) and isinstance(k_node, str) for q_node, k_node, _ in self.pair_wise_contributors
179+
)
180+
assert all(isinstance(q_node, Node) for q_node, _ in self.top_q_marginal_contributors) or all(
181+
isinstance(q_node, str) for q_node, _ in self.top_q_marginal_contributors
182+
)
183+
assert all(isinstance(k_node, Node) for k_node, _ in self.top_k_marginal_contributors) or all(
184+
isinstance(k_node, str) for k_node, _ in self.top_k_marginal_contributors
185+
)
186+
187+
if len(self.pair_wise_contributors) > 0 and isinstance(self.pair_wise_contributors[0][0], str):
188+
return self
189+
190+
self.pair_wise_contributors = [
191+
(q_node.node_id, k_node.node_id, score) for q_node, k_node, score in self.pair_wise_contributors
192+
]
193+
self.top_q_marginal_contributors = [
194+
(q_node.node_id, score) for q_node, score in self.top_q_marginal_contributors
195+
]
196+
self.top_k_marginal_contributors = [
197+
(k_node.node_id, score) for k_node, score in self.top_k_marginal_contributors
198+
]
199+
return self

0 commit comments

Comments
 (0)