2 changes: 1 addition & 1 deletion .gitignore
@@ -22,4 +22,4 @@ docs/source/generated
# docs/source/_static/model_table
**.orig
.venv

.env
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -33,7 +33,7 @@
"notebook.formatOnSave.enabled": true,
"pylint.importStrategy": "fromEnvironment",
"python.testing.pytestArgs": [
"transformer_lens",
"tests"
],
"python.testing.pytestEnabled": true,
"rewrap.autoWrap.enabled": true,
4 changes: 2 additions & 2 deletions tests/acceptance/test_hooked_encoder.py
@@ -225,6 +225,6 @@ def test_input_list_of_strings_mlm(our_bert, huggingface_bert, tokenizer):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
def test_cuda(mlm_tokens):
def test_cuda(tokens):
model = HookedEncoder.from_pretrained(MODEL_NAME)
model(mlm_tokens)
model(tokens)
2 changes: 1 addition & 1 deletion tests/acceptance/test_multi_gpu.py
@@ -111,7 +111,7 @@ def test_cache_device():
torch.device("cuda:1")
)

logits, cache = model.run_with_cache("Hello there", device="cpu")
logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu"))
assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))

model.to("cuda")
19 changes: 19 additions & 0 deletions tests/unit/components/test_attention.py
@@ -80,6 +80,25 @@ def test_attention_load_in_4bit():
assert torch.all(attn.b_V == 0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for half/bfloat16 tests")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_attention_forward_half_precisions(dtype):
# Construct a small attention block
cfg = HookedTransformerConfig(
d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=dtype
)
attn = Attention(cfg)
# Random inputs in the matching dtype
batch = 1
seq = 4
x = torch.rand((batch, seq, cfg.d_model), dtype=dtype).to("cuda")
# Run forward through attention (q,k,v = x)
out = attn(x, x, x)
# The forward pass should not raise and should return a tensor on the CUDA device
assert isinstance(out, torch.Tensor)
assert out.device.type == "cuda"


def test_attention_config_dict():
cfg = {
"n_layers": 12,
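The new half-precision test above is parametrized over bfloat16 and float16 and is skipped without CUDA. When iterating locally, the cases can be selected by name; a minimal sketch using pytest's Python entry point (the file path is taken from the diff header above):

```python
import pytest

# Run only the new half-precision attention tests; without a CUDA device
# they are skipped by the skipif marker shown above.
pytest.main([
    "tests/unit/components/test_attention.py",
    "-k", "test_attention_forward_half_precisions",
])
```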
1 change: 1 addition & 0 deletions tests/unit/factored_matrix/test_multiply_by_scalar.py
@@ -23,6 +23,7 @@
), # Non-scalar Tensor. AssertionError expected.
(torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected.
],
ids=["tensor", "float", "int", "tensor_2d", "tensor_1d"],
)
@pytest.mark.parametrize("leading_dim", [False, True])
@pytest.mark.parametrize("multiply_from_left", [False, True])
25 changes: 18 additions & 7 deletions transformer_lens/components/abstract_attention.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F
from better_abc import abstract_attribute
from jaxtyping import Float, Int
from torch import Tensor
from transformers.utils.import_utils import is_bitsandbytes_available

from transformer_lens.cache.key_value_cache_entry import (
@@ -280,8 +281,7 @@ def forward(
raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
pattern = pattern.to(self.cfg.dtype)
pattern = pattern.to(v.device)
pattern = pattern.to(device=v.device, dtype=v.dtype)
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
if not self.cfg.use_attn_result:
if self.cfg.load_in_4bit:
@@ -301,15 +301,21 @@
self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
)

if self.b_O.device != w.device:
w = w.to(self.b_O.device)
if self.b_O.device != z.device:
z = z.to(self.b_O.device)
# Move output projection weights and bias to the same device as z
# so that the final linear operation occurs on the device of the inputs
if w.device != z.device:
w = w.to(z.device)
b_O: Tensor = self.b_O
if b_O.device != z.device:
b_O = b_O.to(z.device)
# Ensure z has the same dtype as weights used in the output projection
if z.dtype != w.dtype:
z = z.to(w.dtype)

out = F.linear(
z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
w,
self.b_O,
b_O,
)
else:
# Explicitly calculate the attention result so it can be accessed by a hook
@@ -329,6 +335,11 @@ def forward(
self.W_O,
"head_index d_head d_model -> 1 1 head_index d_head d_model",
)
if w.device != z.device:
w = w.to(z.device)
# Ensure z has the same dtype as w before multiplication
if z.dtype != w.dtype:
z = z.to(w.dtype)
z = einops.rearrange(
z, "batch pos head_index d_head -> batch pos head_index d_head 1"
)
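The device/dtype handling added around the output projection follows one pattern: bring the weights and bias onto the device where the activations live, then cast the activations to the weight dtype before the projection. A minimal standalone sketch of that pattern, assuming the rearranged weight layout used above (the `project_output` helper and its tensor names are illustrative, not part of the library):

```python
import torch
import torch.nn.functional as F

def project_output(z: torch.Tensor, w: torch.Tensor, b_O: torch.Tensor) -> torch.Tensor:
    # z: [batch, pos, head_index, d_head]; w: [d_model, head_index * d_head]; b_O: [d_model]
    # Move the projection weights and bias to the activations' device.
    if w.device != z.device:
        w = w.to(z.device)
    if b_O.device != z.device:
        b_O = b_O.to(z.device)
    # Cast activations to the weight dtype (relevant for float16/bfloat16 models).
    if z.dtype != w.dtype:
        z = z.to(w.dtype)
    batch, pos, n_heads, d_head = z.shape
    return F.linear(z.reshape(batch, pos, n_heads * d_head), w, b_O)
```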
34 changes: 33 additions & 1 deletion transformer_lens/model_bridge/bridge.py
@@ -39,6 +39,7 @@
)
from transformer_lens.model_bridge.get_params_util import get_bridge_params
from transformer_lens.utilities.aliases import resolve_alias
from transformer_lens.utilities.devices import move_to_and_update_config

if TYPE_CHECKING:
from transformer_lens.ActivationCache import ActivationCache
@@ -1754,7 +1755,7 @@ def generate(
return output_tokens

def to(self, *args, **kwargs) -> "TransformerBridge":
"""Move model to device or change dtype.
"""Move model to device and/or change dtype.

Args:
args: Positional arguments for nn.Module.to
@@ -1763,6 +1764,37 @@ def to(self, *args, **kwargs) -> "TransformerBridge":
Returns:
Self for chaining
"""
# Extract print_details if provided
print_details = kwargs.pop("print_details", True)

# Handle both device and dtype changes
# torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype),
# to(device=...), to(dtype=...), to(device=..., dtype=...)
target_device, target_dtype = None, None

if len(args) >= 1:
first_arg = args[0]
if isinstance(first_arg, (torch.device, str)):
target_device = first_arg
elif isinstance(first_arg, torch.dtype):
target_dtype = first_arg
if len(args) >= 2:
second_arg = args[1]
if isinstance(second_arg, torch.dtype):
target_dtype = second_arg

# Keyword arguments override positional arguments
if "device" in kwargs:
target_device = kwargs["device"]
if "dtype" in kwargs:
target_dtype = kwargs["dtype"]

if target_device is not None:
move_to_and_update_config(self, target_device, print_details)
if target_dtype is not None:
move_to_and_update_config(self, target_dtype, print_details)

# Move the original model with all original args/kwargs (with print_details removed)
self.original_model = self.original_model.to(*args, **kwargs)
return self

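To make the intent of the new argument handling concrete, here is a hedged usage sketch of the call forms `TransformerBridge.to` now parses. It assumes an already-constructed `TransformerBridge` instance named `bridge`; construction is outside the scope of this diff.

```python
import torch

def demo_to_call_forms(bridge):
    # `bridge` is assumed to be a TransformerBridge instance.
    bridge.to("cuda")                                # positional device (str or torch.device)
    bridge.to(torch.float16)                         # positional dtype
    bridge.to("cuda", torch.bfloat16)                # device then dtype, positional
    bridge.to(device="cuda:1", dtype=torch.float16)  # keyword form; kwargs override positionals
    bridge.to("cpu", print_details=False)            # extra kwarg popped before delegating to nn.Module.to
    return bridge
```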