@@ -38,19 +38,20 @@ def forward(self, x):
 
 
 class T5Attention(nn.Module):
-    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
+    def __init__(self, dim, dim_attn, num_heads, dropout=0.0, device="cuda:0"):
         assert dim_attn % num_heads == 0
         super(T5Attention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
         self.num_heads = num_heads
         self.head_dim = dim_attn // num_heads
+        self.device = device
 
         # layers
-        self.q = nn.Linear(dim, dim_attn, bias=False)
-        self.k = nn.Linear(dim, dim_attn, bias=False)
-        self.v = nn.Linear(dim, dim_attn, bias=False)
-        self.o = nn.Linear(dim_attn, dim, bias=False)
+        self.q = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.k = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.v = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.o = nn.Linear(dim_attn, dim, bias=False, device=device)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, context=None, mask=None, pos_bias=None):
@@ -90,15 +91,16 @@ def forward(self, x, context=None, mask=None, pos_bias=None):
 
 
 class T5FeedForward(nn.Module):
-    def __init__(self, dim, dim_ffn, dropout=0.0):
+    def __init__(self, dim, dim_ffn, dropout=0.0, device="cuda:0"):
         super(T5FeedForward, self).__init__()
         self.dim = dim
         self.dim_ffn = dim_ffn
+        self.device = device
 
         # layers
-        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
-        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
-        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
+        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False, device=device), GELU())
+        self.fc1 = nn.Linear(dim, dim_ffn, bias=False, device=device)
+        self.fc2 = nn.Linear(dim_ffn, dim, bias=False, device=device)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
@@ -110,21 +112,22 @@ def forward(self, x):
 
 
 class T5SelfAttention(nn.Module):
-    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0):
+    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0, device="cuda:0"):
         super(T5SelfAttention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
         self.dim_ffn = dim_ffn
         self.num_heads = num_heads
         self.num_buckets = num_buckets
         self.shared_pos = shared_pos
+        self.device = device
 
         # layers
         self.norm1 = T5LayerNorm(dim)
-        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
+        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, device)
         self.norm2 = T5LayerNorm(dim)
-        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
-        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
+        self.ffn = T5FeedForward(dim, dim_ffn, dropout, device)
+        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device)
 
     def forward(self, x, mask=None, pos_bias=None):
         e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
@@ -134,15 +137,16 @@ def forward(self, x, mask=None, pos_bias=None):
 
 
 class T5RelativeEmbedding(nn.Module):
-    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128, device="cuda:0"):
         super(T5RelativeEmbedding, self).__init__()
         self.num_buckets = num_buckets
         self.num_heads = num_heads
         self.bidirectional = bidirectional
         self.max_dist = max_dist
+        self.device = device
 
         # layers
-        self.embedding = nn.Embedding(num_buckets, num_heads)
+        self.embedding = nn.Embedding(num_buckets, num_heads, device=device)
 
     def forward(self, lq, lk):
         device = self.embedding.weight.device
@@ -257,12 +261,12 @@ def __init__(
         self.shared_pos = shared_pos
 
         # layers
-        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
-        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
+        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, device=device)
+        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
         self.blocks = nn.ModuleList(
             [
-                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
+                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, device)
                 for _ in range(num_layers)
             ]
         )
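For reference, a minimal usage sketch of the changed constructors is below. The import path `t5` and the hyperparameter values are hypothetical, chosen only to illustrate that the new `device` keyword makes the `nn.Linear` / `nn.Embedding` weights get allocated on the requested device at construction time instead of on the CPU default.

```python
import torch
# Hypothetical import path; adjust to wherever this T5 implementation lives.
from t5 import T5SelfAttention

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Arbitrary example sizes; the `device` kwarg is forwarded to every
# nn.Linear / nn.Embedding inside the block.
block = T5SelfAttention(
    dim=512, dim_attn=512, dim_ffn=2048,
    num_heads=8, num_buckets=32,
    shared_pos=False, dropout=0.1, device=device,
)
print(block.attn.q.weight.device)   # placed on `device` at init
print(block.ffn.fc1.weight.device)  # placed on `device` at init

# Note: T5LayerNorm in the hunks above is still constructed without a
# device argument, so a final .to(device) remains a cheap catch-all
# (it is a no-op for parameters that are already on `device`).
block = block.to(device)
x = torch.randn(2, 16, 512, device=device)  # (batch, seq_len, dim)
out = block(x)
```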