Changes from 5 commits
2 changes: 1 addition & 1 deletion .gitignore
@@ -22,4 +22,4 @@ docs/source/generated
# docs/source/_static/model_table
**.orig
.venv

+.env
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -33,7 +33,7 @@
"notebook.formatOnSave.enabled": true,
"pylint.importStrategy": "fromEnvironment",
"python.testing.pytestArgs": [
"transformer_lens",
"tests"
],
"python.testing.pytestEnabled": true,
"rewrap.autoWrap.enabled": true,
4 changes: 2 additions & 2 deletions tests/acceptance/test_hooked_encoder.py
@@ -225,6 +225,6 @@ def test_input_list_of_strings_mlm(our_bert, huggingface_bert, tokenizer):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
-def test_cuda(mlm_tokens):
+def test_cuda(tokens):
    model = HookedEncoder.from_pretrained(MODEL_NAME)
-    model(mlm_tokens)
+    model(tokens)
2 changes: 1 addition & 1 deletion tests/acceptance/test_multi_gpu.py
@@ -111,7 +111,7 @@ def test_cache_device():
torch.device("cuda:1")
)

logits, cache = model.run_with_cache("Hello there", device="cpu")
logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu"))
assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))

model.to("cuda")
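Note: the assertion above compares devices through the test file's norm_device helper. A minimal sketch of what such a helper can look like (illustrative; the actual implementation is not shown in this diff):

```python
import torch

def norm_device(device: torch.device) -> torch.device:
    # Give a bare "cuda" an explicit index so "cuda" and "cuda:0" compare equal.
    if device.type == "cuda" and device.index is None:
        return torch.device("cuda", 0)
    return device

assert norm_device(torch.device("cuda")) == torch.device("cuda:0")
```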
19 changes: 19 additions & 0 deletions tests/unit/components/test_attention.py
@@ -80,6 +80,25 @@ def test_attention_load_in_4bit():
assert torch.all(attn.b_V == 0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for half/bfloat16 tests")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_attention_forward_half_precisions(dtype):
# Construct a small attention block
cfg = HookedTransformerConfig(
d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=dtype
)
attn = Attention(cfg)
# Random inputs in the matching dtype
batch = 1
seq = 4
x = torch.rand((batch, seq, cfg.d_model), dtype=dtype).to("cuda")
# Run forward through attention (q,k,v = x)
out = attn(x, x, x)
# Should not raise and return a tensor on cuda with same dtype as cfg or compatible
assert isinstance(out, torch.Tensor)
assert out.device.type == "cuda"


def test_attention_config_dict():
cfg = {
"n_layers": 12,
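The new test intentionally leaves the attention module on CPU while the input tensor lives on CUDA, exercising the device-alignment logic added to abstract_attention.py below. Outside of that scenario, the conventional pattern is to move the module first; a minimal sketch under the same assumed Attention/HookedTransformerConfig API (requires a CUDA device):

```python
import torch

from transformer_lens import HookedTransformerConfig
from transformer_lens.components import Attention

cfg = HookedTransformerConfig(
    d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=torch.bfloat16
)
attn = Attention(cfg).to("cuda")  # move parameters before the forward pass
x = torch.rand((1, 4, cfg.d_model), dtype=torch.bfloat16, device="cuda")
out = attn(x, x, x)  # q, k, v all taken from x
assert out.device.type == "cuda" and out.dtype == torch.bfloat16
```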
1 change: 1 addition & 0 deletions tests/unit/factored_matrix/test_multiply_by_scalar.py
@@ -23,6 +23,7 @@
), # Non-scalar Tensor. AssertionError expected.
(torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected.
],
+ids=["tensor", "float", "int", "tensor_2d", "tensor_1d"],
)
@pytest.mark.parametrize("leading_dim", [False, True])
@pytest.mark.parametrize("multiply_from_left", [False, True])
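With explicit ids, pytest reports each of the five parametrized cases by name (for example an id ending in [tensor_2d-...]) rather than by positional index, which makes failures easier to trace back to the offending input.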
22 changes: 15 additions & 7 deletions transformer_lens/components/abstract_attention.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F
from better_abc import abstract_attribute
from jaxtyping import Float, Int
+from torch import Tensor
from transformers.utils.import_utils import is_bitsandbytes_available

from transformer_lens.cache.key_value_cache_entry import (
@@ -280,8 +281,7 @@ def forward(
            raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
-        pattern = pattern.to(self.cfg.dtype)
-        pattern = pattern.to(v.device)
+        pattern = pattern.to(device=v.device, dtype=v.dtype)
        z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:
            if self.cfg.load_in_4bit:
@@ -301,15 +301,21 @@ def forward(
self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
)

if self.b_O.device != w.device:
w = w.to(self.b_O.device)
if self.b_O.device != z.device:
z = z.to(self.b_O.device)
# Move output projection weights and bias to the same device as z
# so that the final linear operation occurs on the device of the inputs
if w.device != z.device:
w = w.to(z.device)
b_O: Tensor = self.b_O
if b_O.device != z.device:
b_O = b_O.to(z.device)
# Ensure z has the same dtype as weights used in the output projection
if z.dtype != w.dtype:
z = z.to(w.dtype)

out = F.linear(
z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
w,
self.b_O,
b_O,
)
else:
# Explicitly calculate the attention result so it can be accessed by a hook
@@ -329,6 +335,8 @@ def forward(
                self.W_O,
                "head_index d_head d_model -> 1 1 head_index d_head d_model",
            )
+            if w.device != z.device:
+                w = w.to(z.device)
            z = einops.rearrange(
                z, "batch pos head_index d_head -> batch pos head_index d_head 1"
            )
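The recurring pattern in these hunks: before the output projection, align the weight and bias to the activation's device and the activation to the weight's dtype, instead of assuming they already match. The same idea as a standalone sketch (illustrative, not the library's code):

```python
import torch
import torch.nn.functional as F

def aligned_linear(z: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Bring weight and bias onto the activation's device, then cast the
    # activation to the weight's dtype so F.linear sees consistent operands.
    if w.device != z.device:
        w = w.to(z.device)
    if b.device != z.device:
        b = b.to(z.device)
    if z.dtype != w.dtype:
        z = z.to(w.dtype)
    return F.linear(z, w, b)
```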
34 changes: 33 additions & 1 deletion transformer_lens/model_bridge/bridge.py
@@ -39,6 +39,7 @@
)
from transformer_lens.model_bridge.get_params_util import get_bridge_params
from transformer_lens.utilities.aliases import resolve_alias
+from transformer_lens.utilities.devices import move_to_and_update_config

if TYPE_CHECKING:
from transformer_lens.ActivationCache import ActivationCache
@@ -1754,7 +1755,7 @@ def generate(
return output_tokens

    def to(self, *args, **kwargs) -> "TransformerBridge":
-        """Move model to device or change dtype.
+        """Move model to device and/or change dtype.

        Args:
            args: Positional arguments for nn.Module.to
@@ -1763,6 +1764,37 @@ def generate(
        Returns:
            Self for chaining
        """
+        # Extract print_details if provided
+        print_details = kwargs.pop("print_details", True)
+
+        # Handle both device and dtype changes
+        # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype),
+        # to(device=...), to(dtype=...), to(device=..., dtype=...)
+        target_device, target_dtype = None, None
+
+        if len(args) >= 1:
+            first_arg = args[0]
+            if isinstance(first_arg, (torch.device, str)):
+                target_device = first_arg
+            elif isinstance(first_arg, torch.dtype):
+                target_dtype = first_arg
+        if len(args) >= 2:
+            second_arg = args[1]
+            if isinstance(second_arg, torch.dtype):
+                target_dtype = second_arg
+
+        # Keyword arguments override positional ones
+        if "device" in kwargs:
+            target_device = kwargs["device"]
+        if "dtype" in kwargs:
+            target_dtype = kwargs["dtype"]
+
+        if target_device is not None:
+            move_to_and_update_config(self, target_device, print_details)
+        if target_dtype is not None:
+            move_to_and_update_config(self, target_dtype, print_details)
+
+        # Move the original model with all original args/kwargs (with print_details removed)
+        self.original_model = self.original_model.to(*args, **kwargs)
+        return self

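A usage sketch of the call forms this parsing accepts; print_details is specific to TransformerBridge.to, everything else mirrors nn.Module.to (assumes a TransformerBridge instance constructed elsewhere):

```python
import torch

from transformer_lens.model_bridge import TransformerBridge

def demo_moves(bridge: TransformerBridge) -> TransformerBridge:
    bridge = bridge.to("cuda")                    # positional device
    bridge = bridge.to(torch.float16)             # positional dtype
    bridge = bridge.to("cuda:0", torch.bfloat16)  # device + dtype, positional
    return bridge.to(device="cpu", dtype=torch.float32, print_details=False)
```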