|
20 | 20 | "\n", |
21 | 21 | "There's a surprisingly rich ecosystem of easy ways to create interactive graphics, especially for ML systems. If you're trying to do mechanistic interpretability, the ability to do web dev and to both visualize data and interact with it seems high value! \n", |
22 | 22 | "\n", |
23 | | - "This is a demo of how you can combine HookedTransformer and [Gradio](https://gradio.app/) to create an interactive Neuroscope - a visualization of a neuron's activations on text that will dynamically update as you edit the text. I don't particularly claim that this code is any *good*, but the goal is to illustrate what quickly hacking together a custom visualisation (while knowing fuck all about web dev, like me) can look like! (And as such, I try to explain the basic web dev concepts I use)\n", |
| 23 | + "This is a demo of how you can combine TransformerBridge and [Gradio](https://gradio.app/) to create an interactive Neuroscope - a visualization of a neuron's activations on text that will dynamically update as you edit the text. I don't particularly claim that this code is any *good*, but the goal is to illustrate what quickly hacking together a custom visualisation (while knowing fuck all about web dev, like me) can look like! (And as such, I try to explain the basic web dev concepts I use)\n", |
24 | 24 | "\n", |
25 | 25 | "Note that you'll need to run the code yourself to get the interactive interface, so the cell at the bottom will be blank at first!\n", |
26 | 26 | "\n", |
|
36 | 36 | }, |
37 | 37 | { |
38 | 38 | "cell_type": "code", |
39 | | - "execution_count": 1, |
| 39 | + "execution_count": null, |
40 | 40 | "metadata": {}, |
41 | | - "outputs": [ |
42 | | - { |
43 | | - "name": "stdout", |
44 | | - "output_type": "stream", |
45 | | - "text": [ |
46 | | - "Running as a Jupyter notebook - intended for development only!\n" |
47 | | - ] |
48 | | - }, |
49 | | - { |
50 | | - "name": "stderr", |
51 | | - "output_type": "stream", |
52 | | - "text": [ |
53 | | - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_63049/1105475986.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", |
54 | | - " ipython.magic(\"load_ext autoreload\")\n", |
55 | | - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_63049/1105475986.py:20: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", |
56 | | - " ipython.magic(\"autoreload 2\")\n" |
57 | | - ] |
58 | | - } |
59 | | - ], |
| 41 | + "outputs": [], |
60 | 42 | "source": [ |
61 | 43 | "# NBVAL_IGNORE_OUTPUT\n", |
62 | 44 | "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", |
|
87 | 69 | }, |
88 | 70 | { |
89 | 71 | "cell_type": "code", |
90 | | - "execution_count": 2, |
| 72 | + "execution_count": 4, |
91 | 73 | "metadata": {}, |
92 | 74 | "outputs": [], |
93 | 75 | "source": [ |
94 | 76 | "import gradio as gr\n", |
95 | | - "from transformer_lens import HookedTransformer\n", |
| 77 | + "from transformer_lens.model_bridge import TransformerBridge\n", |
96 | 78 | "from transformer_lens.utils import to_numpy\n", |
97 | 79 | "from IPython.display import HTML" |
98 | 80 | ] |
|
103 | 85 | "source": [ |
104 | 86 | "## Extracting Model Activations\n", |
105 | 87 | "\n", |
106 | | - "We first write some code using HookedTransformer's cache to extract the neuron activations on a given layer and neuron, for a given text" |
| 88 | + "We first write some code using TransformerBridge's cache to extract the neuron activations on a given layer and neuron, for a given text" |
107 | 89 | ] |
108 | 90 | }, |
109 | 91 | { |
110 | 92 | "cell_type": "code", |
111 | | - "execution_count": 12, |
| 93 | + "execution_count": 5, |
112 | 94 | "metadata": {}, |
113 | 95 | "outputs": [ |
114 | 96 | { |
115 | | - "name": "stdout", |
| 97 | + "name": "stderr", |
116 | 98 | "output_type": "stream", |
117 | 99 | "text": [ |
118 | | - "Loaded pretrained model gpt2-small into HookedTransformer\n" |
| 100 | + "The following generation flags are not valid and may be ignored: ['output_attentions']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n", |
| 101 | + "The following generation flags are not valid and may be ignored: ['output_attentions']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n", |
| 102 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:592: FutureWarning: Hook 'W_V' is deprecated and will be removed in a future version. Use 'v.weight' instead.\n", |
| 103 | + " w_v = block.attn.W_V\n", |
| 104 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:579: FutureWarning: Hook 'W_Q' is deprecated and will be removed in a future version. Use 'q.weight' instead.\n", |
| 105 | + " w_q = block.attn.W_Q\n", |
| 106 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:566: FutureWarning: Hook 'W_K' is deprecated and will be removed in a future version. Use 'k.weight' instead.\n", |
| 107 | + " w_k = block.attn.W_K\n", |
| 108 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:605: FutureWarning: Hook 'W_O' is deprecated and will be removed in a future version. Use 'o.weight' instead.\n", |
| 109 | + " w_o = block.attn.W_O\n", |
| 110 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:616: FutureWarning: Hook 'W_in' is deprecated and will be removed in a future version. Use 'in.weight' instead.\n", |
| 111 | + " return torch.stack([block.mlp.W_in for block in self.blocks], dim=0)\n", |
| 112 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:632: FutureWarning: Hook 'W_out' is deprecated and will be removed in a future version. Use 'out.weight' instead.\n", |
| 113 | + " return torch.stack([block.mlp.W_out for block in self.blocks], dim=0)\n", |
| 114 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:637: FutureWarning: Hook 'b_K' is deprecated and will be removed in a future version. Use 'k.bias' instead.\n", |
| 115 | + " return torch.stack([block.attn.b_K for block in self.blocks], dim=0)\n", |
| 116 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:652: FutureWarning: Hook 'b_O' is deprecated and will be removed in a future version. Use 'o.bias' instead.\n", |
| 117 | + " return torch.stack([block.attn.b_O for block in self.blocks], dim=0)\n", |
| 118 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:642: FutureWarning: Hook 'b_Q' is deprecated and will be removed in a future version. Use 'q.bias' instead.\n", |
| 119 | + " return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)\n", |
| 120 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:647: FutureWarning: Hook 'b_V' is deprecated and will be removed in a future version. Use 'v.bias' instead.\n", |
| 121 | + " return torch.stack([block.attn.b_V for block in self.blocks], dim=0)\n", |
| 122 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:657: FutureWarning: Hook 'b_in' is deprecated and will be removed in a future version. Use 'in.bias' instead.\n", |
| 123 | + " return torch.stack([block.mlp.b_in for block in self.blocks], dim=0)\n", |
| 124 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:662: FutureWarning: Hook 'b_out' is deprecated and will be removed in a future version. Use 'out.bias' instead.\n", |
| 125 | + " return torch.stack([block.mlp.b_out for block in self.blocks], dim=0)\n", |
| 126 | + "`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.\n", |
| 127 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_resid_pre' is deprecated and will be removed in a future version. Use 'hook_in' instead.\n", |
| 128 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 129 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_resid_mid' is deprecated and will be removed in a future version. Use 'attn.hook_out' instead.\n", |
| 130 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 131 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_resid_post' is deprecated and will be removed in a future version. Use 'hook_out' instead.\n", |
| 132 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 133 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_attn_in' is deprecated and will be removed in a future version. Use 'attn.hook_in' instead.\n", |
| 134 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 135 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_attn_out' is deprecated and will be removed in a future version. Use 'attn.hook_out' instead.\n", |
| 136 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 137 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_q_input' is deprecated and will be removed in a future version. Use 'attn.q.hook_in' instead.\n", |
| 138 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 139 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_k_input' is deprecated and will be removed in a future version. Use 'attn.k.hook_in' instead.\n", |
| 140 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 141 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_v_input' is deprecated and will be removed in a future version. Use 'attn.v.hook_in' instead.\n", |
| 142 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 143 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_mlp_in' is deprecated and will be removed in a future version. Use 'mlp.hook_in' instead.\n", |
| 144 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 145 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_mlp_out' is deprecated and will be removed in a future version. Use 'mlp.hook_out' instead.\n", |
| 146 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 147 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_normalized' is deprecated and will be removed in a future version. Use 'hook_out' instead.\n", |
| 148 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 149 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_scale' is deprecated and will be removed in a future version. Use 'hook_out' instead.\n", |
| 150 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 151 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_result' is deprecated and will be removed in a future version. Use 'hook_out' instead.\n", |
| 152 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 153 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_q' is deprecated and will be removed in a future version. Use 'q.hook_out' instead.\n", |
| 154 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 155 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_k' is deprecated and will be removed in a future version. Use 'k.hook_out' instead.\n", |
| 156 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 157 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_v' is deprecated and will be removed in a future version. Use 'v.hook_out' instead.\n", |
| 158 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 159 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_z' is deprecated and will be removed in a future version. Use 'hook_hidden_states' instead.\n", |
| 160 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 161 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_pre' is deprecated and will be removed in a future version. Use 'in.hook_out' instead.\n", |
| 162 | + " component_hooks = mod.get_hooks() # type: ignore\n", |
| 163 | + "/Users/bryce/Projects/Lingwave/TransformerLens/transformer_lens/model_bridge/bridge.py:156: FutureWarning: Hook 'hook_post' is deprecated and will be removed in a future version. Use 'out.hook_in' instead.\n", |
| 164 | + " component_hooks = mod.get_hooks() # type: ignore\n" |
119 | 165 | ] |
120 | 166 | } |
121 | 167 | ], |
122 | 168 | "source": [ |
123 | 169 | "# NBVAL_IGNORE_OUTPUT\n", |
124 | | - "model_name = \"gpt2-small\"\n", |
125 | | - "model = HookedTransformer.from_pretrained(model_name)" |
| 170 | + "model_name = \"gpt2\"\n", |
| 171 | + "model = TransformerBridge.boot_transformers(model_name)\n", |
| 172 | + "model.enable_compatibility_mode()" |
126 | 173 | ] |
127 | 174 | }, |
128 | 175 | { |
129 | 176 | "cell_type": "code", |
130 | | - "execution_count": 4, |
| 177 | + "execution_count": 6, |
131 | 178 | "metadata": {}, |
132 | 179 | "outputs": [], |
133 | 180 | "source": [ |
|
454 | 501 | ], |
455 | 502 | "metadata": { |
456 | 503 | "kernelspec": { |
457 | | - "display_name": "Python 3.7.13 ('base')", |
| 504 | + "display_name": ".venv", |
458 | 505 | "language": "python", |
459 | 506 | "name": "python3" |
460 | 507 | }, |
|
468 | 515 | "name": "python", |
469 | 516 | "nbconvert_exporter": "python", |
470 | 517 | "pygments_lexer": "ipython3", |
471 | | - "version": "3.11.9" |
| 518 | + "version": "3.12.7" |
472 | 519 | }, |
473 | | - "orig_nbformat": 4, |
474 | | - "vscode": { |
475 | | - "interpreter": { |
476 | | - "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" |
477 | | - } |
478 | | - } |
| 520 | + "orig_nbformat": 4 |
479 | 521 | }, |
480 | 522 | "nbformat": 4, |
481 | 523 | "nbformat_minor": 2 |
|