File tree — 1 file changed (+6 −0 lines changed)
transformer_lens/model_bridge/generalized_components — 1 file changed (+6 −0 lines changed)
Original file line number | Diff line number | Diff line change
3 3 This module contains the bridge component for MLP layers with joint gating and up-projection.
44"""
55
6+ from __future__ import annotations
67
78from typing import Any , Dict , Optional
89
910import torch
1011
12+ from transformer_lens .model_bridge .generalized_components .base import (
13+ GeneralizedComponent ,
14+ )
15+ from transformer_lens .model_bridge .generalized_components .linear import LinearBridge
1116from transformer_lens .model_bridge .generalized_components .mlp import MLPBridge
1217
1318
@@ -34,6 +39,7 @@ def __init__(
3439 gate_up_config: Gate_Up-specific configuration which holds function to split the joint projection into two
3540 """
3641 super ().__init__ (name , model_config , submodules = submodules )
42+ self .gate_up_config = gate_up_config or {}
3743 self .gate = LinearBridge (name = "gate" , config = model_config )
3844 self .up = LinearBridge (name = "up" , config = model_config )
3945
You can’t perform that action at this time.
0 commit comments