
Be more efficient with corrupted caching #89

@ArthurConmy

```python
# List every corrupted-cache entry, its shape, and whether the online cache
# also holds that hook
for k, v in exp.global_cache.corrupted_cache.items():
    print(k, v.shape, k in exp.global_cache.online_cache)
```

prints lots of unnecessary entries:

```
blocks.0.ln1.hook_scale torch.Size([40, 41, 8, 1]) False
blocks.0.ln1.hook_normalized torch.Size([40, 41, 8, 512]) False
...
blocks.0.attn.hook_attn_scores torch.Size([40, 8, 41, 41]) False
blocks.0.attn.hook_pattern torch.Size([40, 8, 41, 41]) False
blocks.0.attn.hook_z torch.Size([40, 41, 8, 64]) False
```
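For a rough sense of scale, here is a sketch that adds up the sizes of the five entries listed above. It assumes float32 activations (4 bytes per element); the dtype is not shown in the printout. It covers only these layer-0 hooks, not the elided ones. Under that assumption they come to roughly 33 MiB, and `hook_normalized` alone accounts for about 25.6 MiB of that.

```python
from math import prod

# Shapes copied from the printout above (layer 0 only; the elided "..." entries are not included).
SHAPES = {
    "blocks.0.ln1.hook_scale": (40, 41, 8, 1),
    "blocks.0.ln1.hook_normalized": (40, 41, 8, 512),
    "blocks.0.attn.hook_attn_scores": (40, 8, 41, 41),
    "blocks.0.attn.hook_pattern": (40, 8, 41, 41),
    "blocks.0.attn.hook_z": (40, 41, 8, 64),
}


def tensor_bytes(shape, bytes_per_element=4):
    """Bytes for a dense tensor of `shape`; the float32 default is an assumption."""
    return prod(shape) * bytes_per_element


def total_bytes(shapes, bytes_per_element=4):
    """Sum of tensor_bytes over every shape in the mapping."""
    return sum(tensor_bytes(s, bytes_per_element) for s in shapes.values())


for name, shape in SHAPES.items():
    print(f"{name:32s} {tensor_bytes(shape) / 2**20:7.2f} MiB")
print(f"total: {total_bytes(SHAPES) / 2**20:.2f} MiB")
```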

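None of the hooks listed above appears in `online_cache`, yet each one is stored in the corrupted cache. We should only cache activations that definitely matter. This should save memory and let us use bigger models when `corrupted_cache_cpu=False`.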
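One possible shape for the fix, as a minimal sketch: it assumes the set of hook names that matter is known before the corrupted run. The names `needed_hooks`, `make_names_filter` and `prune_cache` are hypothetical and do not come from the codebase. The predicate could decide which caching hooks to register; in TransformerLens, `HookedTransformer.run_with_cache` also accepts a `names_filter` callable of this form. `prune_cache` is a post-hoc fallback for a cache that already exists.

```python
from typing import Any, Callable, Dict, Iterable


def make_names_filter(needed_hooks: Iterable[str]) -> Callable[[str], bool]:
    """Return a predicate that is True only for hook names in `needed_hooks` (hypothetical helper)."""
    needed = frozenset(needed_hooks)
    return lambda name: name in needed


def prune_cache(cache: Dict[str, Any], needed_hooks: Iterable[str]) -> Dict[str, Any]:
    """Drop cached activations whose hook name is not needed (hypothetical fallback)."""
    keep = make_names_filter(needed_hooks)
    return {name: act for name, act in cache.items() if keep(name)}
```

Filtering when the activations are captured is preferable to pruning afterwards: pruning only frees memory after the full cache has already been allocated, so it does not lower peak memory.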
