diff --git a/.gitignore b/.gitignore
index 32da76df8..be879b728 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,4 +22,4 @@ docs/source/generated
 # docs/source/_static/model_table
 **.orig
 .venv
-
+.env
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 63e6e310a..86d448657 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -33,7 +33,7 @@
     "notebook.formatOnSave.enabled": true,
     "pylint.importStrategy": "fromEnvironment",
     "python.testing.pytestArgs": [
-        "transformer_lens",
+        "tests"
     ],
     "python.testing.pytestEnabled": true,
     "rewrap.autoWrap.enabled": true,
diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py
index d0f746d60..797ecbbf9 100644
--- a/tests/acceptance/test_hooked_encoder.py
+++ b/tests/acceptance/test_hooked_encoder.py
@@ -225,6 +225,6 @@ def test_input_list_of_strings_mlm(our_bert, huggingface_bert, tokenizer):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
-def test_cuda(mlm_tokens):
+def test_cuda(tokens):
     model = HookedEncoder.from_pretrained(MODEL_NAME)
-    model(mlm_tokens)
+    model(tokens)
diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py
index 3af5eeeb2..ad407eb6e 100644
--- a/tests/acceptance/test_multi_gpu.py
+++ b/tests/acceptance/test_multi_gpu.py
@@ -111,7 +111,7 @@ def test_cache_device():
         torch.device("cuda:1")
     )
 
-    logits, cache = model.run_with_cache("Hello there", device="cpu")
+    logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu"))
     assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))
 
     model.to("cuda")
diff --git a/tests/unit/components/test_attention.py b/tests/unit/components/test_attention.py
index b386660c6..c473cc491 100644
--- a/tests/unit/components/test_attention.py
+++ b/tests/unit/components/test_attention.py
@@ -80,6 +80,25 @@ def test_attention_load_in_4bit():
     assert torch.all(attn.b_V == 0)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for half/bfloat16 tests")
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+def test_attention_forward_half_precisions(dtype):
+    # Construct a small attention block
+    cfg = HookedTransformerConfig(
+        d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=dtype
+    )
+    attn = Attention(cfg)
+    # Random inputs in the matching dtype
+    batch = 1
+    seq = 4
+    x = torch.rand((batch, seq, cfg.d_model), dtype=dtype).to("cuda")
+    # Run forward through attention (q,k,v = x)
+    out = attn(x, x, x)
+    # Should not raise and return a tensor on cuda with same dtype as cfg or compatible
+    assert isinstance(out, torch.Tensor)
+    assert out.device.type == "cuda"
+
+
 def test_attention_config_dict():
     cfg = {
         "n_layers": 12,
diff --git a/tests/unit/factored_matrix/test_multiply_by_scalar.py b/tests/unit/factored_matrix/test_multiply_by_scalar.py
index 85d0bfbe7..d5fbf29ba 100644
--- a/tests/unit/factored_matrix/test_multiply_by_scalar.py
+++ b/tests/unit/factored_matrix/test_multiply_by_scalar.py
@@ -23,6 +23,7 @@
         ),  # Non-scalar Tensor. AssertionError expected.
         (torch.rand(2), AssertionError),  # Non-scalar Tensor. AssertionError expected.
     ],
+    ids=["tensor", "float", "int", "tensor_2d", "tensor_1d"],
 )
 @pytest.mark.parametrize("leading_dim", [False, True])
 @pytest.mark.parametrize("multiply_from_left", [False, True])
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 5f026f493..0d144a741 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 from better_abc import abstract_attribute
 from jaxtyping import Float, Int
+from torch import Tensor
 from transformers.utils.import_utils import is_bitsandbytes_available
 
 from transformer_lens.cache.key_value_cache_entry import (
@@ -280,8 +281,7 @@ def forward(
             raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
         pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
         pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
-        pattern = pattern.to(self.cfg.dtype)
-        pattern = pattern.to(v.device)
+        pattern = pattern.to(device=v.device, dtype=v.dtype)
         z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
         if not self.cfg.use_attn_result:
             if self.cfg.load_in_4bit:
@@ -301,15 +301,21 @@ def forward(
                     self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
                 )
 
-                if self.b_O.device != w.device:
-                    w = w.to(self.b_O.device)
-                if self.b_O.device != z.device:
-                    z = z.to(self.b_O.device)
+                # Move output projection weights and bias to the same device as z
+                # so that the final linear operation occurs on the device of the inputs
+                if w.device != z.device:
+                    w = w.to(z.device)
+                b_O: Tensor = self.b_O
+                if b_O.device != z.device:
+                    b_O = b_O.to(z.device)
+                # Ensure z has the same dtype as weights used in the output projection
+                if z.dtype != w.dtype:
+                    z = z.to(w.dtype)
 
                 out = F.linear(
                     z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                     w,
-                    self.b_O,
+                    b_O,
                 )
             else:
                 # Explicitly calculate the attention result so it can be accessed by a hook
@@ -329,6 +335,11 @@ def forward(
                     self.W_O,
                     "head_index d_head d_model -> 1 1 head_index d_head d_model",
                 )
+                if w.device != z.device:
+                    w = w.to(z.device)
+                # Ensure z has the same dtype as w before multiplication
+                if z.dtype != w.dtype:
+                    z = z.to(w.dtype)
                 z = einops.rearrange(
                     z, "batch pos head_index d_head -> batch pos head_index d_head 1"
                 )
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index 141af5e2e..81fbfe406 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -39,6 +39,7 @@
 )
 from transformer_lens.model_bridge.get_params_util import get_bridge_params
 from transformer_lens.utilities.aliases import resolve_alias
+from transformer_lens.utilities.devices import move_to_and_update_config
 
 if TYPE_CHECKING:
     from transformer_lens.ActivationCache import ActivationCache
@@ -1754,7 +1755,7 @@ def generate(
         return output_tokens
 
     def to(self, *args, **kwargs) -> "TransformerBridge":
-        """Move model to device or change dtype.
+        """Move model to device and/or change dtype.
 
         Args:
             args: Positional arguments for nn.Module.to
@@ -1763,6 +1764,37 @@ def to(self, *args, **kwargs) -> "TransformerBridge":
         Returns:
             Self for chaining
         """
+        # Extract print_details if provided
+        print_details = kwargs.pop("print_details", True)
+
+        # Handle both device and dtype changes
+        # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype),
+        # to(device=...), to(dtype=...), to(device=..., dtype=...)
+        target_device, target_dtype = None, None
+
+        if len(args) >= 1:
+            first_arg = args[0]
+            if isinstance(first_arg, (torch.device, str)):
+                target_device = first_arg
+            elif isinstance(first_arg, torch.dtype):
+                target_dtype = first_arg
+        if len(args) >= 2:
+            second_arg = args[1]
+            if isinstance(second_arg, torch.dtype):
+                target_dtype = second_arg
+
+        # these override positional args
+        if "device" in kwargs:
+            target_device = kwargs["device"]
+        if "dtype" in kwargs:
+            target_dtype = kwargs["dtype"]
+
+        if target_device is not None:
+            move_to_and_update_config(self, target_device, print_details)
+        if target_dtype is not None:
+            move_to_and_update_config(self, target_dtype, print_details)
+
+        # Move the original model with all original args/kwargs (with print_details removed)
         self.original_model = self.original_model.to(*args, **kwargs)
         return self
 
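As a quick illustration of the argument handling introduced in TransformerBridge.to() above, the following standalone sketch mirrors the same positional/keyword parsing so the mapping from call forms to a (device, dtype) pair can be checked without loading a model. parse_to_args is a hypothetical helper written only for this example; it is not part of TransformerLens or the diff.

import torch


def parse_to_args(*args, **kwargs):
    # Hypothetical helper: replicates the parsing added to TransformerBridge.to() above.
    target_device, target_dtype = None, None
    if len(args) >= 1:
        first_arg = args[0]
        if isinstance(first_arg, (torch.device, str)):
            target_device = first_arg
        elif isinstance(first_arg, torch.dtype):
            target_dtype = first_arg
    if len(args) >= 2 and isinstance(args[1], torch.dtype):
        target_dtype = args[1]
    # Keyword arguments override positional ones, as in the bridge implementation.
    target_device = kwargs.get("device", target_device)
    target_dtype = kwargs.get("dtype", target_dtype)
    return target_device, target_dtype


# The call forms enumerated in the bridge comment above resolve as expected.
assert parse_to_args("cuda") == ("cuda", None)
assert parse_to_args(torch.float16) == (None, torch.float16)
assert parse_to_args("cpu", torch.bfloat16) == ("cpu", torch.bfloat16)
assert parse_to_args(device="cpu", dtype=torch.float32) == ("cpu", torch.float32)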