From 2be123d8908556ac5dad58953827f9574ca70a2d Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Sun, 16 Nov 2025 11:00:09 -0800
Subject: [PATCH 1/5] Address device/dtype mismatches that caused failures in
 various contexts.

We also update .gitignore to exclude .env (a commonly used local file
exclusion), e.g. to allow collaborators to add their own HF_TOKEN for the
test suite.

Core Fixes:
-----------

transformer_lens/components/abstract_attention.py:
- Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
  where tensors are upcast to float32 for numerical stability while cfg.dtype
  remains float16/bfloat16
- Add explicit device/dtype synchronization for output projection:
  * Move weights (W_O) and bias (b_O) to match input device (z.device)
  * Ensure z matches weight dtype before final linear operation

transformer_lens/model_bridge/bridge.py:
- Replace direct original_model.to() call with move_to_and_update_config()
  utility to ensure:
  * All bridge components (not just original_model) are moved to target device
  * cfg.device and cfg.dtype stay synchronized with actual model state
  * Multi-GPU cache tensors remain on correct devices

Test Fixes:
-----------

tests/acceptance/test_hooked_encoder.py:
- Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'

tests/acceptance/test_multi_gpu.py:
- Update test_cache_device() to pass torch.device("cpu") instead of string
  "cpu" for proper device type validation

tests/unit/components/test_attention.py:
- Add test_attention_forward_half_precisions() to validate attention works
  correctly with bfloat16/float16 dtypes on CUDA devices

tests/unit/factored_matrix/test_multiply_by_scalar.py:
- Add test IDs to parametrize decorators to avoid pytest cache issues when
  random numbers appear in test names

Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
---
 .gitignore                                |  2 +-
 .vscode/settings.json                     |  2 +-
 tests/acceptance/test_hooked_encoder.py   |  4 ++--
 tests/acceptance/test_multi_gpu.py        |  2 +-
 tests/unit/components/test_attention.py   | 19 +++++++++++++++++
 .../test_multiply_by_scalar.py            |  1 +
 .../components/abstract_attention.py      | 21 ++++++++++++-------
 transformer_lens/model_bridge/bridge.py   |  6 +++++-
 8 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/.gitignore b/.gitignore
index 32da76df8..be879b728 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,4 +22,4 @@ docs/source/generated
 # docs/source/_static/model_table
 **.orig
 .venv
-
+.env
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 63e6e310a..86d448657 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -33,7 +33,7 @@
     "notebook.formatOnSave.enabled": true,
     "pylint.importStrategy": "fromEnvironment",
     "python.testing.pytestArgs": [
-        "transformer_lens",
+        "tests"
     ],
     "python.testing.pytestEnabled": true,
     "rewrap.autoWrap.enabled": true,
diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py
index e62d35574..3afa69561 100644
--- a/tests/acceptance/test_hooked_encoder.py
+++ b/tests/acceptance/test_hooked_encoder.py
@@ -222,6 +222,6 @@ def test_input_list_of_strings_mlm(our_bert, huggingface_bert, tokenizer):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
-def test_cuda(mlm_tokens):
+def test_cuda(tokens):
     model = HookedEncoder.from_pretrained(MODEL_NAME)
-    model(mlm_tokens)
+    model(tokens)
diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py
index 3af5eeeb2..ad407eb6e 100644
--- a/tests/acceptance/test_multi_gpu.py
+++ b/tests/acceptance/test_multi_gpu.py
@@ -111,7 +111,7 @@ def test_cache_device():
         torch.device("cuda:1")
     )
 
-    logits, cache = model.run_with_cache("Hello there", device="cpu")
+    logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu"))
     assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))
 
     model.to("cuda")
diff --git a/tests/unit/components/test_attention.py b/tests/unit/components/test_attention.py
index b386660c6..c473cc491 100644
--- a/tests/unit/components/test_attention.py
+++ b/tests/unit/components/test_attention.py
@@ -80,6 +80,25 @@ def test_attention_load_in_4bit():
     assert torch.all(attn.b_V == 0)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for half/bfloat16 tests")
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+def test_attention_forward_half_precisions(dtype):
+    # Construct a small attention block
+    cfg = HookedTransformerConfig(
+        d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=dtype
+    )
+    attn = Attention(cfg)
+    # Random inputs in the matching dtype
+    batch = 1
+    seq = 4
+    x = torch.rand((batch, seq, cfg.d_model), dtype=dtype).to("cuda")
+    # Run forward through attention (q,k,v = x)
+    out = attn(x, x, x)
+    # Should not raise and return a tensor on cuda with same dtype as cfg or compatible
+    assert isinstance(out, torch.Tensor)
+    assert out.device.type == "cuda"
+
+
 def test_attention_config_dict():
     cfg = {
         "n_layers": 12,
diff --git a/tests/unit/factored_matrix/test_multiply_by_scalar.py b/tests/unit/factored_matrix/test_multiply_by_scalar.py
index 85d0bfbe7..d5fbf29ba 100644
--- a/tests/unit/factored_matrix/test_multiply_by_scalar.py
+++ b/tests/unit/factored_matrix/test_multiply_by_scalar.py
@@ -23,6 +23,7 @@
         ),  # Non-scalar Tensor. AssertionError expected.
         (torch.rand(2), AssertionError),  # Non-scalar Tensor. AssertionError expected.
     ],
+    ids=["tensor", "float", "int", "tensor_2d", "tensor_1d"],
 )
 @pytest.mark.parametrize("leading_dim", [False, True])
 @pytest.mark.parametrize("multiply_from_left", [False, True])
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 5f026f493..7894d9a84 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -280,8 +280,7 @@ def forward(
             raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
         pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
         pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
-        pattern = pattern.to(self.cfg.dtype)
-        pattern = pattern.to(v.device)
+        pattern = pattern.to(device=v.device, dtype=v.dtype)
         z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
         if not self.cfg.use_attn_result:
             if self.cfg.load_in_4bit:
@@ -301,15 +300,21 @@ def forward(
                     self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
                 )
-                if self.b_O.device != w.device:
-                    w = w.to(self.b_O.device)
-                if self.b_O.device != z.device:
-                    z = z.to(self.b_O.device)
+                # Move output projection weights and bias to the same device as z
+                # so that the final linear operation occurs on the device of the inputs
+                if w.device != z.device:
+                    w = w.to(z.device)
+                b_O = self.b_O
+                if b_O.device != z.device:
+                    b_O = b_O.to(z.device)
+                # Ensure z has the same dtype as weights used in the output projection
+                if z.dtype != w.dtype:
+                    z = z.to(w.dtype)
                 out = F.linear(
                     z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                     w,
-                    self.b_O,
+                    b_O,
                 )
         else:
             # Explicitly calculate the attention result so it can be accessed by a hook
@@ -329,6 +334,8 @@ def forward(
                 self.W_O,
                 "head_index d_head d_model -> 1 1 head_index d_head d_model",
             )
+            if w.device != z.device:
+                w = w.to(z.device)
             z = einops.rearrange(
                 z, "batch pos head_index d_head -> batch pos head_index d_head 1"
             )
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index e1cde6c0e..3920f8402 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -6082,7 +6082,11 @@ def to(self, *args, **kwargs) -> "TransformerBridge":
         Returns:
             Self for chaining
         """
-        self.original_model = self.original_model.to(*args, **kwargs)
+        # Use the shared utility which also updates `cfg` on device/dtype changes
+        from transformer_lens.utilities.devices import move_to_and_update_config
+
+        # Move underlying model (and update config) instead of directly calling nn.Module.to
+        move_to_and_update_config(self, *args, **kwargs)
         return self
 
     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge":

From f2efb3752625aba082f8719506624ad5ede49a37 Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Wed, 26 Nov 2025 14:55:21 -0800
Subject: [PATCH 2/5] Align TransformerBridge.to() with PyTorch nn.Module
 semantics

Enhance to() method to properly handle both device and dtype arguments in all
supported PyTorch formats (positional, keyword, combined). Separately invoke
move_to_and_update_config for device/dtype to update cfg while delegating the
actual tensor movement to original_model.to() with original args/kwargs. This
ensures TransformerBridge respects standard PyTorch behavior for model.to()
calls.
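
For illustration only (not part of the change itself), the torch.nn.Module.to()
call forms this method is aligned with, demonstrated on a stand-in nn.Linear
module; TransformerBridge.to() is meant to accept the same forms:

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 4)                          # stand-in for a bridge instance
    m.to("cpu")                                  # positional device
    m.to(torch.float32)                          # positional dtype
    m.to("cpu", torch.float32)                   # positional device + dtype
    m.to(device="cpu", dtype=torch.float32)      # keyword form

In each form, cfg is updated via move_to_and_update_config while the original
args/kwargs are still forwarded to original_model.to().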
---
 transformer_lens/model_bridge/bridge.py | 40 +++++++++++++++++++++----
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index f83169da8..08dd2e001 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -39,6 +39,7 @@
 )
 from transformer_lens.model_bridge.get_params_util import get_bridge_params
 from transformer_lens.utilities.aliases import resolve_alias
+from transformer_lens.utilities.devices import move_to_and_update_config
 
 if TYPE_CHECKING:
     from transformer_lens.ActivationCache import ActivationCache
@@ -1754,7 +1755,7 @@ def generate(
         return output_tokens
 
     def to(self, *args, **kwargs) -> "TransformerBridge":
-        """Move model to device or change dtype.
+        """Move model to device and/or change dtype.
 
         Args:
             args: Positional arguments for nn.Module.to
@@ -1763,11 +1764,38 @@ def to(self, *args, **kwargs) -> "TransformerBridge":
         Returns:
             Self for chaining
         """
-        # Use the shared utility which also updates `cfg` on device/dtype changes
-        from transformer_lens.utilities.devices import move_to_and_update_config
-
-        # Move underlying model (and update config) instead of directly calling nn.Module.to
-        move_to_and_update_config(self, *args, **kwargs)
+        # Extract print_details if provided
+        print_details = kwargs.pop("print_details", True)
+
+        # Handle both device and dtype changes
+        # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype),
+        # to(device=...), to(dtype=...), to(device=..., dtype=...)
+        target_device, target_dtype = None, None
+
+        if len(args) >= 1:
+            first_arg = args[0]
+            if isinstance(first_arg, (torch.device, str)):
+                target_device = first_arg
+            elif isinstance(first_arg, torch.dtype):
+                target_dtype = first_arg
+        if len(args) >= 2:
+            second_arg = args[1]
+            if isinstance(second_arg, torch.dtype):
+                target_dtype = second_arg
+
+        # these override positional args
+        if "device" in kwargs:
+            target_device = kwargs["device"]
+        if "dtype" in kwargs:
+            target_dtype = kwargs["dtype"]
+
+        if target_device is not None:
+            move_to_and_update_config(self, target_device, print_details)
+        if target_dtype is not None:
+            move_to_and_update_config(self, target_dtype, print_details)
+
+        # Move the original model with all original args/kwargs (with print_details removed)
+        self.original_model = self.original_model.to(*args, **kwargs)
         return self
 
     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> "TransformerBridge":

From b4660f88a4a5d1c79b0e9c0e2c7e8e87098861d1 Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Sat, 29 Nov 2025 13:16:27 -0800
Subject: [PATCH 3/5] minor formatting and type fix

---
 transformer_lens/components/abstract_attention.py |  3 ++-
 transformer_lens/model_bridge/bridge.py           | 12 ++++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 7894d9a84..a0db051a3 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from better_abc import abstract_attribute
 from jaxtyping import Float, Int
 from transformers.utils.import_utils import is_bitsandbytes_available
 
 from transformer_lens.cache.key_value_cache_entry import (
@@ -304,7 +305,7 @@ def forward(
                 # so that the final linear operation occurs on the device of the inputs
                 if w.device != z.device:
                     w = w.to(z.device)
-                b_O = self.b_O
+                b_O: Tensor = self.b_O
                 if b_O.device != z.device:
                     b_O = b_O.to(z.device)
                 # Ensure z has the same dtype as weights used in the output projection
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index 08dd2e001..81fbfe406 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -1766,12 +1766,12 @@ def to(self, *args, **kwargs) -> "TransformerBridge":
         """
         # Extract print_details if provided
         print_details = kwargs.pop("print_details", True)
-
+
         # Handle both device and dtype changes
-        # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype), 
+        # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype),
         # to(device=...), to(dtype=...), to(device=..., dtype=...)
         target_device, target_dtype = None, None
-
+
         if len(args) >= 1:
             first_arg = args[0]
             if isinstance(first_arg, (torch.device, str)):
@@ -1782,18 +1782,18 @@ def to(self, *args, **kwargs) -> "TransformerBridge":
             second_arg = args[1]
             if isinstance(second_arg, torch.dtype):
                 target_dtype = second_arg
-
+
         # these override positional args
         if "device" in kwargs:
             target_device = kwargs["device"]
         if "dtype" in kwargs:
             target_dtype = kwargs["dtype"]
-
+
         if target_device is not None:
             move_to_and_update_config(self, target_device, print_details)
         if target_dtype is not None:
             move_to_and_update_config(self, target_dtype, print_details)
-
+
         # Move the original model with all original args/kwargs (with print_details removed)
         self.original_model = self.original_model.to(*args, **kwargs)
         return self

From ab280f18fc15b4f9e768776983d509a006ee23af Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Sat, 29 Nov 2025 13:21:56 -0800
Subject: [PATCH 4/5] rerun isort fix

---
 transformer_lens/components/abstract_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index a0db051a3..02400f89f 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -6,9 +6,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
 from better_abc import abstract_attribute
 from jaxtyping import Float, Int
+from torch import Tensor
 from transformers.utils.import_utils import is_bitsandbytes_available
 
 from transformer_lens.cache.key_value_cache_entry import (

From 9d4e643589f1c527e470e5be2395dd6064fbfcd7 Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Sat, 29 Nov 2025 15:10:24 -0800
Subject: [PATCH 5/5] minor sync enhancement

---
 transformer_lens/components/abstract_attention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 02400f89f..0d144a741 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -337,6 +337,9 @@ def forward(
                 self.W_O,
                 "head_index d_head d_model -> 1 1 head_index d_head d_model",
             )
             if w.device != z.device:
                 w = w.to(z.device)
+            # Ensure z has the same dtype as w before multiplication
+            if z.dtype != w.dtype:
+                z = z.to(w.dtype)
             z = einops.rearrange(
                 z, "batch pos head_index d_head -> batch pos head_index d_head 1"
             )
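
Illustrative sketch (not part of the patches above): the device/dtype alignment
idiom the attention changes apply before the output projection, shown as a
standalone snippet. The helper name project_out and the tensor shapes are
hypothetical, and it assumes a PyTorch build with bfloat16 matmul support on
the host device.

    import torch
    import torch.nn.functional as F

    def project_out(z: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Move the projection weight and bias to the activation's device (z.device)
        if w.device != z.device:
            w = w.to(z.device)
        if b.device != z.device:
            b = b.to(z.device)
        # Cast the activation to the weight dtype before the final linear op
        if z.dtype != w.dtype:
            z = z.to(w.dtype)
        return F.linear(z, w, b)

    # Example: float32 activations projected with bfloat16 weights without a dtype error
    z = torch.randn(2, 4, 8)                        # [batch, pos, d_in], float32
    w = torch.randn(16, 8, dtype=torch.bfloat16)    # [d_out, d_in]
    b = torch.zeros(16, dtype=torch.bfloat16)
    out = project_out(z, w, b)                      # bfloat16 output on the weights' dtype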