"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from typing_extensions import Literal\n",
+ "\n",
+ "\n",
+ "def stack_head_pattern_from_cache(\n",
+ " cache,\n",
+ ") -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n",
+ " \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n",
+ " stacked_head_pattern = torch.stack(\n",
+ " [cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n",
+ " )\n",
+ " stacked_head_pattern = einops.rearrange(\n",
+ " stacked_head_pattern,\n",
+ " \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\",\n",
+ " )\n",
+ " return stacked_head_pattern\n",
+ "\n",
+ "\n",
+ "def attr_patch_head_pattern(\n",
+ " clean_cache: ActivationCache,\n",
+ " corrupted_cache: ActivationCache,\n",
+ " corrupted_grad_cache: ActivationCache,\n",
+ ") -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n",
+ " labels = HEAD_NAMES\n",
+ "\n",
+ " clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n",
+ " corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n",
+ " corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n",
+ " head_pattern_attr = einops.reduce(\n",
+ " corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n",
+ " \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " return head_pattern_attr, labels\n",
+ "\n",
+ "\n",
+ "head_pattern_attr, labels = attr_patch_head_pattern(\n",
+ " clean_cache, corrupted_cache, corrupted_grad_cache\n",
+ ")\n",
+ "\n",
+ "plot_attention_attr(\n",
+ " einops.rearrange(\n",
+ " head_pattern_attr,\n",
+ " \"(layer head) dest src -> layer head dest src\",\n",
+ " layer=model.cfg.n_layers,\n",
+ " head=model.cfg.n_heads,\n",
+ " ),\n",
+ " clean_tokens,\n",
+ " index=0,\n",
+ " title=\"Head Pattern Attribution Patching\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def get_head_vector_grad_input_from_grad_cache(\n",
+ " grad_cache: ActivationCache, activation_name: Literal[\"q\", \"k\", \"v\"], layer: int\n",
+ ") -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n",
+ " vector_grad = grad_cache[activation_name, layer]\n",
+ " ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n",
+ " attn_layer_object = model.blocks[layer].attn\n",
+ " if activation_name == \"q\":\n",
+ " W = attn_layer_object.W_Q\n",
+ " elif activation_name == \"k\":\n",
+ " W = attn_layer_object.W_K\n",
+ " elif activation_name == \"v\":\n",
+ " W = attn_layer_object.W_V\n",
+ " else:\n",
+ " raise ValueError(\"Invalid activation name\")\n",
+ "\n",
+ " return einsum(\n",
+ " \"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\",\n",
+ " vector_grad,\n",
+ " ln_scales.squeeze(-1),\n",
+ " W,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def get_stacked_head_vector_grad_input(\n",
+ " grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]\n",
+ ") -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n",
+ " return torch.stack(\n",
+ " [\n",
+ " get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l)\n",
+ " for l in range(model.cfg.n_layers)\n",
+ " ],\n",
+ " dim=0,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def get_full_vector_grad_input(\n",
+ " grad_cache,\n",
+ ") -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n",
+ " return torch.stack(\n",
+ " [\n",
+ " get_stacked_head_vector_grad_input(grad_cache, activation_name)\n",
+ " for activation_name in [\"q\", \"k\", \"v\"]\n",
+ " ],\n",
+ " dim=0,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def attr_patch_head_path(\n",
+ " clean_cache: ActivationCache,\n",
+ " corrupted_cache: ActivationCache,\n",
+ " corrupted_grad_cache: ActivationCache,\n",
+ ") -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n",
+ " \"\"\"\n",
+ " Computes the attribution patch along the path between each pair of heads.\n",
+ "\n",
+ " Sets this to zero for the path from any late head to any early head\n",
+ "\n",
+ " \"\"\"\n",
+ " start_labels = HEAD_NAMES\n",
+ " end_labels = HEAD_NAMES_QKV\n",
+ " full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n",
+ " clean_head_result_stack = clean_cache.stack_head_results(-1)\n",
+ " corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n",
+ " diff_head_result = einops.rearrange(\n",
+ " clean_head_result_stack - corrupted_head_result_stack,\n",
+ " \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n",
+ " layer=model.cfg.n_layers,\n",
+ " head_index=model.cfg.n_heads,\n",
+ " )\n",
+ " path_attr = einsum(\n",
+ " \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\",\n",
+ " full_vector_grad_input,\n",
+ " diff_head_result,\n",
+ " )\n",
+ " correct_layer_order_mask = (\n",
+ " torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]\n",
+ " > torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]\n",
+ " ).to(path_attr.device)\n",
+ " zero = torch.zeros(1, device=path_attr.device)\n",
+ " path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n",
+ "\n",
+ " path_attr = einops.rearrange(\n",
+ " path_attr,\n",
+ " \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n",
+ " )\n",
+ " return path_attr, end_labels, start_labels\n",
+ "\n",
+ "\n",
+ "head_path_attr, end_labels, start_labels = attr_patch_head_path(\n",
+ " clean_cache, corrupted_cache, corrupted_grad_cache\n",
+ ")\n",
+ "imshow(\n",
+ " head_path_attr.sum(-1),\n",
+ " y=end_labels,\n",
+ " yaxis=\"Path End (Head Input)\",\n",
+ " x=start_labels,\n",
+ " xaxis=\"Path Start (Head Output)\",\n",
+ " title=\"Head Path Attribution Patching\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n",
+ "line(head_out_values)\n",
+ "top_head_indices = head_out_indices[:22].sort().values\n",
+ "top_end_indices = []\n",
+ "top_end_labels = []\n",
+ "top_start_indices = []\n",
+ "top_start_labels = []\n",
+ "for i in top_head_indices:\n",
+ " i = i.item()\n",
+ " top_start_indices.append(i)\n",
+ " top_start_labels.append(start_labels[i])\n",
+ " for j in range(3):\n",
+ " top_end_indices.append(3 * i + j)\n",
+ " top_end_labels.append(end_labels[3 * i + j])\n",
+ "\n",
+ "imshow(\n",
+ " head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n",
+ " y=top_end_labels,\n",
+ " yaxis=\"Path End (Head Input)\",\n",
+ " x=top_start_labels,\n",
+ " xaxis=\"Path Start (Head Output)\",\n",
+ " title=\"Head Path Attribution Patching (Filtered for Top Heads)\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n",
+ " imshow(\n",
+ " head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),\n",
+ " y=top_end_labels[j::3],\n",
+ " yaxis=\"Path End (Head Input)\",\n",
+ " x=top_start_labels,\n",
+ " xaxis=\"Path Start (Head Output)\",\n",
+ " title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\",\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "top_head_path_attr = einops.rearrange(\n",
+ " head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n",
+ " \"(head_end qkv) head_start -> qkv head_end head_start\",\n",
+ " qkv=3,\n",
+ ")\n",
+ "imshow(\n",
+ " top_head_path_attr,\n",
+ " y=[i[:-1] for i in top_end_labels[::3]],\n",
+ " yaxis=\"Path End (Head Input)\",\n",
+ " x=top_start_labels,\n",
+ " xaxis=\"Path Start (Head Output)\",\n",
+ " title=f\"Head Path Attribution Patching (Filtered for Top Heads)\",\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Query\", \"Key\", \"Value\"],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "interesting_heads = [\n",
+ " 5 * model.cfg.n_heads + 5,\n",
+ " 8 * model.cfg.n_heads + 6,\n",
+ " 9 * model.cfg.n_heads + 9,\n",
+ "]\n",
+ "interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n",
+ "for head_index, label in zip(interesting_heads, interesting_head_labels):\n",
+ " in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)\n",
+ " out_paths = head_path_attr[:, head_index].sum(-1)\n",
+ " out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n",
+ " all_paths = torch.cat([in_paths, out_paths], dim=0)\n",
+ " all_paths = einops.rearrange(\n",
+ " all_paths,\n",
+ " \"path_type (layer head) -> path_type layer head\",\n",
+ " layer=model.cfg.n_layers,\n",
+ " head=model.cfg.n_heads,\n",
+ " )\n",
+ " imshow(\n",
+ " all_paths,\n",
+ " facet_col=0,\n",
+ " facet_labels=[\n",
+ " \"Query (In)\",\n",
+ " \"Key (In)\",\n",
+ " \"Value (In)\",\n",
+ " \"Query (Out)\",\n",
+ " \"Key (Out)\",\n",
+ " \"Value (Out)\",\n",
+ " ],\n",
+ " title=f\"Input and Output Paths for head {label}\",\n",
+ " yaxis=\"Layer\",\n",
+ " xaxis=\"Head\",\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " ## Validating Attribution vs Activation Patching\n",
+ " Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n",
+ " My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n",
+ " See more discussion in the accompanying blog post!\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "attribution_cache_dict = {}\n",
+ "for key in corrupted_grad_cache.cache_dict.keys():\n",
+ " attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (\n",
+ " clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key]\n",
+ " )\n",
+ "attr_cache = ActivationCache(attribution_cache_dict, model)"
+ ]
+ },
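+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " As a quick sanity check (an illustrative addition, assuming `head_out_attr` from the earlier head output section is still in scope), summing the cache's \"z\" attributions over batch, position and d_head should reproduce the per-head output attribution we computed before."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check (illustrative addition): reducing the attribution cache's \"z\"\n",
+    "# activations per (layer, head) should match the earlier head output attribution.\n",
+    "head_out_attr_from_cache = einops.reduce(\n",
+    "    attr_cache.stack_activation(\"z\"),\n",
+    "    \"layer batch pos head_index d_head -> (layer head_index)\",\n",
+    "    \"sum\",\n",
+    ")\n",
+    "print(\n",
+    "    \"Max abs difference vs head_out_attr:\",\n",
+    "    (head_out_attr_from_cache - head_out_attr.sum(-1)).abs().max().item(),\n",
+    ")"
+   ]
+  },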
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " By block: For each head we patch the starting residual stream, attention output + MLP output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "str_tokens = model.to_str_tokens(clean_tokens[0])\n",
+ "context_length = len(str_tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "95a5290e11b64b6a95ef5dd37d027c7a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/180 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "be204ae96db74023b957e592a9a0fde9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/180 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a2409bc6d2524634a48f4556a6773415",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/180 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "every_block_act_patch_result = patching.get_act_patch_block_every(\n",
+ " model, corrupted_tokens, clean_cache, ioi_metric\n",
+ ")\n",
+ "imshow(\n",
+ " every_block_act_patch_result,\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n",
+ " title=\"Activation Patching Per Block\",\n",
+ " xaxis=\"Position\",\n",
+ " yaxis=\"Layer\",\n",
+ " zmax=1,\n",
+ " zmin=-1,\n",
+ " x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def get_attr_patch_block_every(attr_cache):\n",
+ " resid_pre_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"resid_pre\"),\n",
+ " \"layer batch pos d_model -> layer pos\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " attn_out_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"attn_out\"),\n",
+ " \"layer batch pos d_model -> layer pos\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " mlp_out_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"mlp_out\"),\n",
+ " \"layer batch pos d_model -> layer pos\",\n",
+ " \"sum\",\n",
+ " )\n",
+ "\n",
+ " every_block_attr_patch_result = torch.stack(\n",
+ " [resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0\n",
+ " )\n",
+ " return every_block_attr_patch_result\n",
+ "\n",
+ "\n",
+ "every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n",
+ "imshow(\n",
+ " every_block_attr_patch_result,\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n",
+ " title=\"Attribution Patching Per Block\",\n",
+ " xaxis=\"Position\",\n",
+ " yaxis=\"Layer\",\n",
+ " zmax=1,\n",
+ " zmin=-1,\n",
+ " x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "scatter(\n",
+ " y=every_block_attr_patch_result.reshape(3, -1),\n",
+ " x=every_block_act_patch_result.reshape(3, -1),\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n",
+ " title=\"Attribution vs Activation Patching Per Block\",\n",
+ " xaxis=\"Activation Patch\",\n",
+ " yaxis=\"Attribution Patch\",\n",
+ " hover=[\n",
+ " f\"Layer {l}, Position {p}, |{str_tokens[p]}|\"\n",
+ " for l in range(model.cfg.n_layers)\n",
+ " for p in range(context_length)\n",
+ " ],\n",
+ " color=einops.repeat(\n",
+ " torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length\n",
+ " ),\n",
+ " color_continuous_scale=\"Portland\",\n",
+ ")"
+ ]
+ },
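+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " To roughly quantify \"decent approximation\", here's a minimal sketch (an illustrative addition) computing the correlation between attribution and activation patching for each block component, assuming both result tensors above are still in scope."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative addition: Pearson correlation between attribution and\n",
+    "# activation patching values for each block component.\n",
+    "for i, name in enumerate([\"Residual Stream\", \"Attn Output\", \"MLP Output\"]):\n",
+    "    pair = torch.stack(\n",
+    "        [\n",
+    "            every_block_attr_patch_result[i].flatten(),\n",
+    "            every_block_act_patch_result[i].flatten(),\n",
+    "        ]\n",
+    "    )\n",
+    "    corr = torch.corrcoef(pair)[0, 1].item()\n",
+    "    print(f\"{name}: correlation = {corr:.3f}\")"
+   ]
+  },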
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "18b2e6b0985b40cd8c0cd1a16ba62975",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/144 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2d034be6501e4c9db1c290b1705e60f8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/144 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e2f3a429be1745e9a874d2fd4881841d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/144 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f8e5bf04563c4b0da801f3f5e1b08e7e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/144 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5ae4c563073843a68df3b590cb8b4dc3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/144 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(\n",
+ " model, corrupted_tokens, clean_cache, ioi_metric\n",
+ ")\n",
+ "imshow(\n",
+ " every_head_all_pos_act_patch_result,\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n",
+ " title=\"Activation Patching Per Head (All Pos)\",\n",
+ " xaxis=\"Head\",\n",
+ " yaxis=\"Layer\",\n",
+ " zmax=1,\n",
+ " zmin=-1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def get_attr_patch_attn_head_all_pos_every(attr_cache):\n",
+ " head_out_all_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"z\"),\n",
+ " \"layer batch pos head_index d_head -> layer head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " head_q_all_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"q\"),\n",
+ " \"layer batch pos head_index d_head -> layer head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " head_k_all_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"k\"),\n",
+ " \"layer batch pos head_index d_head -> layer head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " head_v_all_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"v\"),\n",
+ " \"layer batch pos head_index d_head -> layer head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " head_pattern_all_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"pattern\"),\n",
+ " \"layer batch head_index dest_pos src_pos -> layer head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ "\n",
+ " return torch.stack(\n",
+ " [\n",
+ " head_out_all_pos_attr,\n",
+ " head_q_all_pos_attr,\n",
+ " head_k_all_pos_attr,\n",
+ " head_v_all_pos_attr,\n",
+ " head_pattern_all_pos_attr,\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ "\n",
+ "every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(\n",
+ " attr_cache\n",
+ ")\n",
+ "imshow(\n",
+ " every_head_all_pos_attr_patch_result,\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n",
+ " title=\"Attribution Patching Per Head (All Pos)\",\n",
+ " xaxis=\"Head\",\n",
+ " yaxis=\"Layer\",\n",
+ " zmax=1,\n",
+ " zmin=-1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "scatter(\n",
+ " y=every_head_all_pos_attr_patch_result.reshape(5, -1),\n",
+ " x=every_head_all_pos_act_patch_result.reshape(5, -1),\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n",
+ " title=\"Attribution vs Activation Patching Per Head (All Pos)\",\n",
+ " xaxis=\"Activation Patch\",\n",
+ " yaxis=\"Attribution Patch\",\n",
+ " include_diag=True,\n",
+ " hover=head_out_labels,\n",
+ " color=einops.repeat(\n",
+ " torch.arange(model.cfg.n_layers),\n",
+ " \"layer -> (layer head)\",\n",
+ " head=model.cfg.n_heads,\n",
+ " ),\n",
+ " color_continuous_scale=\"Portland\",\n",
+ ")"
+ ]
+ },
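+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " We can also rank heads by the gap between the two methods programmatically (an illustrative addition, assuming the two result tensors and `head_out_labels` are in scope); this picks out the same outlier heads discussed below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative addition: rank heads by |attribution - activation| per activation type.\n",
+    "gap = (\n",
+    "    every_head_all_pos_attr_patch_result - every_head_all_pos_act_patch_result\n",
+    ").abs()\n",
+    "for i, name in enumerate([\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"]):\n",
+    "    vals, idxs = gap[i].flatten().sort(descending=True)\n",
+    "    top = [\n",
+    "        f\"{head_out_labels[j]} ({v:.3f})\"\n",
+    "        for v, j in zip(vals[:3].tolist(), idxs[:3].tolist())\n",
+    "    ]\n",
+    "    print(f\"{name}: largest gaps:\", \", \".join(top))"
+   ]
+  },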
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n",
+ " Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "graph_tok_labels = [\n",
+ " f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))\n",
+ "]\n",
+ "imshow(\n",
+ " clean_cache[\"pattern\", 5][:, 5],\n",
+ " x=graph_tok_labels,\n",
+ " y=graph_tok_labels,\n",
+ " facet_col=0,\n",
+ " title=\"Attention for Head L5H5\",\n",
+ " facet_name=\"Prompt\",\n",
+ ")\n",
+ "imshow(\n",
+ " clean_cache[\"pattern\", 10][:, 7],\n",
+ " x=graph_tok_labels,\n",
+ " y=graph_tok_labels,\n",
+ " facet_col=0,\n",
+ " title=\"Attention for Head L10H7\",\n",
+ " facet_name=\"Prompt\",\n",
+ ")\n",
+ "imshow(\n",
+ " clean_cache[\"pattern\", 11][:, 10],\n",
+ " x=graph_tok_labels,\n",
+ " y=graph_tok_labels,\n",
+ " facet_col=0,\n",
+ " title=\"Attention for Head L11H10\",\n",
+ " facet_name=\"Prompt\",\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# [markdown]"
+ ]
+ },
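+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " To quantify the saturation claim, here's a minimal sketch (an illustrative addition) checking the maximum attention probability each of these heads puts on a single source token; values near 1 mean the softmax is saturated, where a first-order approximation of the pattern is poor."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative addition: how saturated are these heads' attention patterns?\n",
+    "for layer, head in [(5, 5), (10, 7), (11, 10)]:\n",
+    "    # Max attention probability over source positions, per prompt and destination\n",
+    "    max_attn = clean_cache[\"pattern\", layer][:, head].max(dim=-1).values\n",
+    "    print(f\"L{layer}H{head}: mean max attention = {max_attn.mean().item():.3f}\")"
+   ]
+  },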
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "06f39489001845849fbc7446a07066f4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2160 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1c2eba74a11f47d0a78dd78bd0e60b84",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2160 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f92f8c8c2ffa4d889def1b4214b6ec04",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2160 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "65d0fd01f6dc40409c61f5fde0e30470",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2160 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "52452e90576545f8b12a1bbad5fc7c08",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2160 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(\n",
+ " model, corrupted_tokens, clean_cache, ioi_metric\n",
+ ")\n",
+ "every_head_by_pos_act_patch_result = einops.rearrange(\n",
+ " every_head_by_pos_act_patch_result,\n",
+ " \"act_type layer pos head -> act_type (layer head) pos\",\n",
+ ")\n",
+ "imshow(\n",
+ " every_head_by_pos_act_patch_result,\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n",
+ " title=\"Activation Patching Per Head (By Pos)\",\n",
+ " xaxis=\"Position\",\n",
+ " yaxis=\"Layer & Head\",\n",
+ " zmax=1,\n",
+ " zmin=-1,\n",
+ " x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",
+ " y=head_out_labels,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def get_attr_patch_attn_head_by_pos_every(attr_cache):\n",
+ " head_out_by_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"z\"),\n",
+ " \"layer batch pos head_index d_head -> layer pos head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " head_q_by_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"q\"),\n",
+ " \"layer batch pos head_index d_head -> layer pos head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " head_k_by_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"k\"),\n",
+ " \"layer batch pos head_index d_head -> layer pos head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " head_v_by_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"v\"),\n",
+ " \"layer batch pos head_index d_head -> layer pos head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ " head_pattern_by_pos_attr = einops.reduce(\n",
+ " attr_cache.stack_activation(\"pattern\"),\n",
+ " \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n",
+ " \"sum\",\n",
+ " )\n",
+ "\n",
+ " return torch.stack(\n",
+ " [\n",
+ " head_out_by_pos_attr,\n",
+ " head_q_by_pos_attr,\n",
+ " head_k_by_pos_attr,\n",
+ " head_v_by_pos_attr,\n",
+ " head_pattern_by_pos_attr,\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ "\n",
+ "every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n",
+ "every_head_by_pos_attr_patch_result = einops.rearrange(\n",
+ " every_head_by_pos_attr_patch_result,\n",
+ " \"act_type layer pos head -> act_type (layer head) pos\",\n",
+ ")\n",
+ "imshow(\n",
+ " every_head_by_pos_attr_patch_result,\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n",
+ " title=\"Attribution Patching Per Head (By Pos)\",\n",
+ " xaxis=\"Position\",\n",
+ " yaxis=\"Layer & Head\",\n",
+ " zmax=1,\n",
+ " zmin=-1,\n",
+ " x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",
+ " y=head_out_labels,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "scatter(\n",
+ " y=every_head_by_pos_attr_patch_result.reshape(5, -1),\n",
+ " x=every_head_by_pos_act_patch_result.reshape(5, -1),\n",
+ " facet_col=0,\n",
+ " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n",
+ " title=\"Attribution vs Activation Patching Per Head (by Pos)\",\n",
+ " xaxis=\"Activation Patch\",\n",
+ " yaxis=\"Attribution Patch\",\n",
+ " include_diag=True,\n",
+ " hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels],\n",
+ " color=einops.repeat(\n",
+ " torch.arange(model.cfg.n_layers),\n",
+ " \"layer -> (layer head pos)\",\n",
+ " head=model.cfg.n_heads,\n",
+ " pos=15,\n",
+ " ),\n",
+ " color_continuous_scale=\"Portland\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " ## Factual Knowledge Patching Example\n",
+ " Incomplete, but maybe of interest!\n",
+ " Note that I have better results with the corrupted prompt as having random words rather than Colosseum."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loaded pretrained model gpt2-xl into HookedTransformer\n",
+ "Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n",
+ "Tokenized answer: [' Paris']\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Performance on answer token:\n",
+ "Rank: 0 Logit: 20.73 Prob: 95.80% Token: | Paris|\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "Performance on answer token:\n",
+ "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n",
+ "Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n",
+ "Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n",
+ "Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n",
+ "Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n",
+ "Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n",
+ "Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n",
+ "Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n",
+ "Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n",
+ "Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Ranks of the answer tokens: [(' Paris', 0)]\n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n",
+ "Tokenized answer: [' Rome']\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Performance on answer token:\n",
+ "Rank: 0 Logit: 20.02 Prob: 83.70% Token: | Rome|\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "Performance on answer token:\n",
+ "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n",
+ "Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n",
+ "Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n",
+ "Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n",
+ "Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n",
+ "Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n",
+ "Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n",
+ "Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n",
+ "Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n",
+ "Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Ranks of the answer tokens: [(' Rome', 0)]\n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "gpt2_xl = TransformerBridge.boot_transformers(\"gpt2-xl\")\n",
+ "gpt2_xl.enable_compatibility_mode()\n",
+ "clean_prompt = \"The Eiffel Tower is located in the city of\"\n",
+ "clean_answer = \" Paris\"\n",
+ "# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n",
+ "corrupted_prompt = \"The Colosseum is located in the city of\"\n",
+ "corrupted_answer = \" Rome\"\n",
+ "utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n",
+ "utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n",
+ "corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n",
+ "\n",
+ "\n",
+ "def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n",
+ " return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Clean logit diff: 10.634519577026367\n",
+ "Corrupted logit diff: -8.988396644592285\n",
+ "Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n",
+ "Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n",
+ "CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n",
+ "corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n",
+ "CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n",
+ "\n",
+ "\n",
+ "def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n",
+ " return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (\n",
+ " CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL\n",
+ " )\n",
+ "\n",
+ "\n",
+ "print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n",
+ "print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n",
+ "print(\"Clean Metric:\", factual_metric(clean_logits))\n",
+ "print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n",
+ "Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"
+ ]
+ }
+ ],
+ "source": [
+ "clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n",
+ "clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n",
+ "corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n",
+ "corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n",
+ "print(\"Clean:\", clean_str_tokens)\n",
+ "print(\"Corrupted:\", corrupted_str_tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b767eef7a3cd49b9b3cb6e5301463f08",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/48 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def act_patch_residual(clean_cache, corrupted_tokens, model: TransformerBridge, metric):\n",
+ " if len(corrupted_tokens.shape) == 2:\n",
+ " corrupted_tokens = corrupted_tokens[0]\n",
+ " residual_patches = torch.zeros(\n",
+ " (model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device\n",
+ " )\n",
+ "\n",
+ " def residual_hook(resid_pre, hook, layer, pos):\n",
+ " resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n",
+ " return resid_pre\n",
+ "\n",
+ " for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n",
+ " for pos in range(len(corrupted_tokens)):\n",
+ " patched_logits = model.run_with_hooks(\n",
+ " corrupted_tokens,\n",
+ " fwd_hooks=[\n",
+ " (\n",
+ " f\"blocks.{layer}.hook_resid_pre\",\n",
+ " partial(residual_hook, layer=layer, pos=pos),\n",
+ " )\n",
+ " ],\n",
+ " )\n",
+ " residual_patches[layer, pos] = metric(patched_logits).item()\n",
+ " return residual_patches\n",
+ "\n",
+ "\n",
+ "residual_act_patch = act_patch_residual(\n",
+ " clean_cache, corrupted_tokens, gpt2_xl, factual_metric\n",
+ ")\n",
+ "\n",
+ "imshow(\n",
+ " residual_act_patch,\n",
+ " title=\"Factual Recall Patching (Residual)\",\n",
+ " xaxis=\"Position\",\n",
+ " yaxis=\"Layer\",\n",
+ " x=clean_str_tokens,\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}