This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 65c3047

Author: Ubuntu (committed)

Commit message: add adapter

1 parent 1e51262 · commit 65c3047

File tree: 1 file changed (+14, -3 lines)

src/gluonnlp/models/transformer.py

Lines changed: 14 additions & 3 deletions
@@ -16,7 +16,7 @@
 from typing import Optional, Tuple, List
 from ..utils.registry import Registry
 from ..attention_cell import MultiHeadAttentionCell, gen_self_attn_mask, gen_mem_attn_mask
-from ..layers import PositionalEmbedding, PositionwiseFFN, InitializerType
+from ..layers import PositionalEmbedding, PositionwiseFFN, InitializerType, AdapterModule
 from ..utils.config import CfgNode as CN
 from ..sequence_sampler import BaseStepDecoder
 
@@ -149,7 +149,9 @@ def __init__(self,
                  bias_initializer: Optional[InitializerType] = 'zeros',
                  activation: str = 'relu',
                  dtype='float32',
-                 layout='NT'):
+                 layout='NT',
+                 use_adapter=False,
+                 adapter_config={}):
         """
 
         Parameters
@@ -186,6 +188,8 @@ def __init__(self,
         self._pre_norm = pre_norm
         self._dtype = dtype
         self._layout = layout
+        self._use_adapter = use_adapter
+        self._adapter_config = adapter_config
         assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \
                                        'Only "TN" and "NT" are accepted!'.format(layout)
         assert self._units % self._num_heads == 0, 'units must be divisive by the number of heads'
@@ -204,6 +208,9 @@ def __init__(self,
                                        weight_initializer=weight_initializer,
                                        bias_initializer=bias_initializer,
                                        dtype=self._dtype)
+
+        if self._use_adapter:
+            self.adapter_layer_attn = AdapterModule(in_units=units, adapter_config=adapter_config)
         attention_layout = 'NTK' if self._layout == 'NT' else 'TNK'
         self.attention_cell = \
             MultiHeadAttentionCell(
@@ -225,7 +232,9 @@ def __init__(self,
                                    layer_norm_eps=layer_norm_eps,
                                    activation=activation,
                                    pre_norm=pre_norm,
-                                   dtype=self._dtype)
+                                   dtype=self._dtype,
+                                   use_adapter=self._use_adapter,
+                                   adapter_config=self._adapter_config)
 
     @property
     def layout(self) -> str:
@@ -265,6 +274,8 @@ def forward(self, data, attn_mask):
         out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
         out = self.attention_proj(out)
         out = self.dropout_layer(out)
+        if self._use_adapter:
+            out = self.adapter_layer_attn(out)
         out = out + data
         if not self._pre_norm:
             out = self.layer_norm(out)
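The commit wires an `AdapterModule` (imported from `..layers`, not shown here) into the encoder layer: it is applied to the attention output after the projection and dropout, just before the residual connection, and the same switches are forwarded to the `PositionwiseFFN`. Below is a minimal sketch of what such a module typically looks like (a bottleneck adapter), assuming the usual down-project / activation / up-project / residual structure; the class name `BottleneckAdapter` and the `unit` / `activation` config keys are illustrative assumptions, not the repository's actual implementation.

import mxnet as mx
from mxnet.gluon import nn


class BottleneckAdapter(nn.HybridBlock):
    """Sketch of a bottleneck adapter: down-project -> activation -> up-project + residual."""

    def __init__(self, in_units, adapter_config=None):
        super().__init__()
        cfg = adapter_config or {}
        bottleneck_units = cfg.get('unit', 64)        # assumed config key
        act_type = cfg.get('activation', 'relu')      # assumed config key
        self.down_proj = nn.Dense(bottleneck_units, in_units=in_units, flatten=False)
        self.activation = nn.Activation(act_type)
        self.up_proj = nn.Dense(in_units, in_units=bottleneck_units, flatten=False)

    def forward(self, data):
        # data: (batch, seq_length, in_units) in the 'NT' layout used by the encoder layer
        out = self.down_proj(data)
        out = self.activation(out)
        out = self.up_proj(out)
        # The residual keeps the adapted layer close to the original transformation,
        # so only the small projections need to be fine-tuned.
        return out + data

One design note on the diff itself: `adapter_config={}` is a mutable default argument; a `None` default resolved inside `__init__` (as in the sketch) avoids sharing one dict across layer instances.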

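For context, a hedged usage sketch of the new switches on the encoder layer defined in src/gluonnlp/models/transformer.py. The class name `TransformerEncoderLayer`, the non-adapter constructor defaults, the `(out, attn_weight)` return, and the `adapter_config` keys are assumptions based on the surrounding file, not part of this diff.

import mxnet as mx
from gluonnlp.models.transformer import TransformerEncoderLayer

mx.npx.set_np()

layer = TransformerEncoderLayer(units=512,
                                hidden_size=2048,
                                num_heads=8,
                                layout='NT',
                                use_adapter=True,             # new in this commit
                                adapter_config={'unit': 64})  # assumed config key
layer.initialize()

batch_size, seq_length, units = 2, 16, 512
data = mx.np.random.normal(size=(batch_size, seq_length, units))            # 'NT': (batch, time, C)
attn_mask = mx.np.ones((batch_size, seq_length, seq_length), dtype='int32')  # full self-attention mask
out, attn_weight = layer(data, attn_mask)  # adapter runs after attention dropout, before the residual
print(out.shape)                           # expected: (2, 16, 512)

With use_adapter=False (the default) the layer behaves exactly as before, so existing checkpoints and configs are unaffected.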