Skip to content

Commit c2e2d90

Browse files
committed
fixed doc string issues
1 parent 7c3ff2e commit c2e2d90

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

transformer_lens/model_bridge/generalized_components/joint_gate_up_mlp.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
This module contains the bridge component for MLP layers with joint gating and up-projection.
44
"""
55

6+
from __future__ import annotations
67

78
from typing import Any, Dict, Optional
89

910
import torch
1011

12+
from transformer_lens.model_bridge.generalized_components.base import (
13+
GeneralizedComponent,
14+
)
15+
from transformer_lens.model_bridge.generalized_components.linear import LinearBridge
1116
from transformer_lens.model_bridge.generalized_components.mlp import MLPBridge
1217

1318

@@ -34,6 +39,7 @@ def __init__(
3439
gate_up_config: Gate_Up-specific configuration which holds function to split the joint projection into two
3540
"""
3641
super().__init__(name, model_config, submodules=submodules)
42+
self.gate_up_config = gate_up_config or {}
3743
self.gate = LinearBridge(name="gate", config=model_config)
3844
self.up = LinearBridge(name="up", config=model_config)
3945

0 commit comments

Comments
 (0)