diff --git a/demos/Patchscopes_Generation_Demo.ipynb b/demos/Patchscopes_Generation_Demo.ipynb
index 49c4655d4..b249d112e 100644
--- a/demos/Patchscopes_Generation_Demo.ipynb
+++ b/demos/Patchscopes_Generation_Demo.ipynb
@@ -65,7 +65,7 @@
     "from typing import List, Callable, Tuple, Union\n",
     "from functools import partial\n",
     "from jaxtyping import Float\n",
-    "from transformer_lens import HookedTransformer\n",
+    "from transformer_lens.model_bridge import TransformerBridge\n",
     "from transformer_lens.ActivationCache import ActivationCache\n",
     "import transformer_lens.utils as utils\n",
     "from transformer_lens.hook_points import (\n",
@@ -148,7 +148,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -217,7 +217,8 @@
    "source": [
     "# NBVAL_IGNORE_OUTPUT\n",
     "# I'm using an M2 macbook air, so I use CPU for better support\n",
-    "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=\"cpu\")\n",
+    "model = TransformerBridge.boot_transformers(\"gpt2\", device=\"cpu\")\n",
+    "model.enable_compatibility_mode()\n",
     "model.eval()"
    ]
   },
@@ -263,17 +264,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def get_source_representation(prompts: List[str], layer_id: int, model: HookedTransformer, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n",
+    "def get_source_representation(prompts: List[str], layer_id: int, model: TransformerBridge, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n",
     "    \"\"\"Get source hidden representation represented by (S, i, M, l)\n",
     "    \n",
     "    Args:\n",
     "        - prompts (List[str]): a list of source prompts\n",
     "        - layer_id (int): the layer id of the model\n",
-    "        - model (HookedTransformer): the source model\n",
+    "        - model (TransformerBridge): the source model\n",
     "        - pos_id (Union[int, List[int]]): the position id(s) of the model, if None, return all positions\n",
     "\n",
     "    Returns:\n",
@@ -325,19 +326,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# recall the target representation (T,i*,f,M*,l*), and we also need the hidden representation from our source model (S, i, M, l)\n",
-    "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: HookedTransformer, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n",
+    "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: TransformerBridge, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n",
     "    \"\"\"Feed the source hidden representation to the target model\n",
     "    \n",
     "    Args:\n",
     "        - source_rep (torch.Tensor): the source hidden representation\n",
     "        - prompt (List[str]): the target prompt\n",
     "        - f (Callable): the mapping function\n",
-    "        - model (HookedTransformer): the target model\n",
+    "        - model (TransformerBridge): the target model\n",
     "        - layer_id (int): the layer id of the target model\n",
     "        - pos_id (Union[int, List[int]]): the position id(s) of the target model, if None, return all positions\n",
     "    \"\"\"\n",
@@ -417,11 +418,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def generate_with_patching(model: HookedTransformer, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n",
+    "def generate_with_patching(model: TransformerBridge, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n",
     "    temp_prompts = prompts\n",
     "    input_tokens = model.to_tokens(temp_prompts)\n",
     "    for _ in range(max_new_tokens):\n",
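
For anyone verifying the migration locally, here is a minimal sketch of the new loading path. The import, `TransformerBridge.boot_transformers("gpt2", device="cpu")`, `enable_compatibility_mode()`, `eval()`, and `to_tokens` are taken directly from this diff; the `run_with_cache` call and the `blocks.8.hook_resid_post` hook name come from the classic `HookedTransformer` API that compatibility mode is presumably meant to preserve, so treat those two as assumptions about the bridge rather than behavior confirmed by this hunk.

```python
# Minimal smoke test for the TransformerBridge loading path used in this demo.
# ASSUMPTION: compatibility mode exposes HookedTransformer-style caching and
# hook names (e.g. "blocks.8.hook_resid_post"); that is the premise of the
# migration, not something this diff alone verifies.
import torch
from transformer_lens.model_bridge import TransformerBridge

model = TransformerBridge.boot_transformers("gpt2", device="cpu")
model.enable_compatibility_mode()
model.eval()

tokens = model.to_tokens(["Patchscopes is a"])

# Cache activations the way the notebook's get_source_representation does.
with torch.no_grad():
    logits, cache = model.run_with_cache(tokens)

print(cache["blocks.8.hook_resid_post"].shape)
```

If compatibility mode works as intended, the printed shape should be `[1, seq_len, 768]`, matching gpt2-small's `d_model`, and the notebook's hook-based patching functions should run unchanged apart from the type annotations updated above.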