diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 9a802831e..8d06a39db 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -226,7 +226,7 @@ jobs: fail-fast: false matrix: notebook: - # - "Activation_Patching_in_TL_Demo" + - "Activation_Patching_in_TL_Demo" # - "Attribution_Patching_Demo" - "ARENA_Content" - "BERT" diff --git a/demos/Activation_Patching_in_TL_Demo.ipynb b/demos/Activation_Patching_in_TL_Demo.ipynb index 3be728cb1..1e57298c1 100644 --- a/demos/Activation_Patching_in_TL_Demo.ipynb +++ b/demos/Activation_Patching_in_TL_Demo.ipynb @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -58,7 +58,7 @@ " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git\n", + " %pip install transformer_lens\n", " # Install my janky personal plotting utils\n", " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", "except:\n", @@ -67,9 +67,9 @@ " from IPython import get_ipython\n", "\n", " ipython = get_ipython()\n", - " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", - " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + " # Code to automatically update the TransformerBridge code as its edited without restarting the kernel\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")" ] }, { @@ -127,11 +127,7 @@ "source": [ "import transformer_lens\n", "import transformer_lens.utils as utils\n", - "from transformer_lens.hook_points import (\n", - " HookedRootModule,\n", - " HookPoint,\n", - ") # Hooking utilities\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache" + "from transformer_lens.model_bridge import TransformerBridge" ] }, { @@ -215,7 +211,8 @@ } ], "source": [ - "model = HookedTransformer.from_pretrained(\"gpt2-small\")" + "model = TransformerBridge.boot_transformers(\"gpt2\")\n", + "model.enable_compatibility_mode()" ] }, { @@ -955,7 +952,8 @@ } ], "source": [ - "attn_only = HookedTransformer.from_pretrained(\"attn-only-2l\")\n", + "attn_only = TransformerBridge.boot_transformers(\"attn-only-2l\")\n", + "attn_only.enable_compatibility_mode()\n", "batch = 4\n", "seq_len = 20\n", "rand_tokens_A = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n",