75 changes: 63 additions & 12 deletions scripts/classification/train_classification.py
@@ -5,7 +5,9 @@
import json
import random
import pandas as pd
import mxnet.numpy_extension as _mx_npx
import os
import json
import logging
import time
import argparse
@@ -92,13 +94,27 @@ def parse_args():
help='the path to training dataset')
parser.add_argument('--warmup_ratio', type=float, default=0.1,
help='Ratio of warmup steps in the learning rate scheduler.')
parser.add_argument('--method', type=str, default='full', choices=['full', 'bias', 'subbias', 'adapter'],
help='the fine-tuning method to use')
Contributor: Would you like to edit the README file to include results for (at least some of) the different choices (and references to the papers)?



args = parser.parse_args()
return args


def change_adapter_cfg(cfg, task):
adapter_config = {'adapter_fusion':False,
'task_names':[task.task_name],
task.task_name:{'type':'Basic','unit':64}}
cfg.defrost()
cfg.MODEL.use_adapter = True
cfg.MODEL.adapter_config = json.dumps(adapter_config)
cfg.freeze()
return cfg

def get_network(model_name,
ctx_l,
method='full',
checkpoint_path=None,
backbone_path=None,
task=None):
@@ -109,13 +125,16 @@ def get_network(model_name,
use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
Model, cfg, tokenizer, download_params_path, _ = \
get_backbone(model_name, load_backbone=not backbone_path)

if method == 'adapter':
cfg = change_adapter_cfg(cfg, task)
backbone = Model.from_cfg(cfg)
# Load local backbone parameters if backbone_path provided.
# Otherwise, download backbone parameters from gluon zoo.

backbone_params_path = backbone_path if backbone_path else download_params_path
if checkpoint_path is None:
backbone.load_parameters(backbone_params_path, ignore_extra=True,
backbone.load_parameters(backbone_params_path, ignore_extra=True, allow_missing=True,
Contributor @leezu (Apr 2, 2021): Would the following be safer?

Suggested change:
backbone.load_parameters(backbone_params_path, ignore_extra=True, allow_missing=True,
backbone.load_parameters(backbone_params_path, ignore_extra=True, allow_missing=(method == 'adapter'),

ctx=ctx_l, cast_dtype=True)
num_params, num_fixed_params \
= count_parameters(deduplicate_param_dict(backbone.collect_params()))
@@ -219,6 +238,8 @@ def train(args):
#random seed
set_seed(args.seed)
level = logging.INFO
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
detail_dir = os.path.join(args.output_dir, args.task_name)
if not os.path.exists(detail_dir):
os.mkdir(detail_dir)
@@ -228,11 +249,12 @@
console=(local_rank == 0))
logging.info(args)
cfg, tokenizer, classify_net, use_segmentation = \
get_network(args.model_name, ctx_l,
get_network(args.model_name, ctx_l, args.method,
args.param_checkpoint,
args.backbone_path,
task)


logging.info('Prepare training data')
train_data, _ = get_task_data(args, task, tokenizer, segment='train')
train_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()),
@@ -253,6 +275,22 @@
sampler=sampler)


if args.method == 'full':
target_params_name = classify_net.collect_params().keys()
elif args.method == 'bias':
target_params_name = [key
for key in classify_net.collect_params() if
key.endswith('bias') or key.endswith('beta') or 'out_proj' in key]
elif args.method == 'adapter':
target_params_name = [key
for key in classify_net.collect_params() if
'adapter' in key or 'out_proj' in key]
for name in classify_net.collect_params():
if name not in target_params_name:
classify_net.collect_params()[name].grad_req = 'null'

target_params = {name:classify_net.collect_params()[name] for name in target_params_name}


param_dict = classify_net.collect_params()
# Do not apply weight decay to all the LayerNorm and bias
@@ -269,7 +307,7 @@
if local_rank == 0:
writer = SummaryWriter(logdir=os.path.join(args.output_dir,
args.task_name + '_tensorboard_' +
str(args.lr) + '_' + str(args.epochs)))
str(args.lr) + '_' + str(args.epochs) + '_' + str(args.method)))
if args.comm_backend == 'horovod':
# Horovod: fetch and broadcast parameters
hvd.broadcast_parameters(param_dict, root_rank=0)
@@ -290,10 +328,12 @@
optimizer_params = {'learning_rate': args.lr,
'wd': args.wd,
'lr_scheduler': lr_scheduler}


if args.comm_backend == 'horovod':
trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
trainer = hvd.DistributedTrainer(target_params, args.optimizer, optimizer_params)
else:
trainer = mx.gluon.Trainer(classify_net.collect_params(),
trainer = mx.gluon.Trainer(target_params,
'adamw',
optimizer_params)

@@ -376,16 +416,22 @@
log_gnorm = 0
log_step = 0
if local_rank == 0 and (i == max_update - 1 or i%(max_update//args.epochs) == 0 and i>0):
ckpt_name = '{}_{}_{}.params'.format(args.model_name,
args.task_name,
(i + 1))
ckpt_name = '{}_{}_{}_{}.params'.format(args.model_name,
args.task_name,
(i + 1),
args.method)

tmp_params = classify_net._collect_params_with_prefix()
params_saved = os.path.join(detail_dir, ckpt_name)
classify_net.save_parameters(params_saved)
arg_dict = {key: tmp_params[key]._reduce() for key in target_params}
_mx_npx.savez(params_saved, **arg_dict)
logging.info('Params saved in: {}'.format(params_saved))
for metric in metrics:
metric.reset()

end_time = time.time()
logging.info('Total costs:{}'.format(end_time - start_time))



def evaluate(args):
@@ -410,19 +456,24 @@
str(ctx_l)))

cfg, tokenizer, classify_net, use_segmentation = \
get_network(args.model_name, ctx_l,
get_network(args.model_name, ctx_l, args.method,
args.param_checkpoint,
args.backbone_path,
task)

candidate_ckpt = []
detail_dir = os.path.join(args.output_dir, args.task_name)
for name in os.listdir(detail_dir):
if name.endswith('.params') and args.task_name in name and args.model_name in name:
if name.endswith(args.method + '.params') and args.task_name in name and args.model_name in name:
candidate_ckpt.append(os.path.join(detail_dir, name))
candidate_ckpt.sort(reverse=False)
best_ckpt = {}
metrics = task.metric
def evaluate_by_ckpt(ckpt_name, best_ckpt):
classify_net.load_parameters(ckpt_name, ctx=ctx_l, cast_dtype=True)
loaded = _mx_npx.load(ckpt_name)
full_dict = {'params': loaded, 'filename': ckpt_name}
classify_net.load_dict(full_dict, ctx_l, allow_missing=True,
ignore_extra=True, cast_dtype=True)
logging.info('Prepare dev data')

dev_data, label = get_task_data(args, task, tokenizer, segment='eval')
87 changes: 86 additions & 1 deletion src/gluonnlp/layers.py
@@ -17,7 +17,7 @@
"""Layers."""
__all__ = ['PositionalEmbedding', 'SinusoidalPositionalEmbedding',
'LearnedPositionalEmbedding', 'BucketPositionalEmbedding', 'AdaptiveEmbedding',
'PositionwiseFFN', 'ProjectedAdaptiveLogSoftmaxWithLoss']
'PositionwiseFFN', 'ProjectedAdaptiveLogSoftmaxWithLoss', 'AdapterModule']

import math
from collections import OrderedDict
@@ -28,6 +28,8 @@
import numpy as _np
from typing import Union, Optional, List, Dict
from .op import relative_position_bucket
#from .attention_cell import MultiHeadAttentionCell
Contributor: This would be a circular import, as attention_cell also imports layers.

from .layers import SinusoidalPositionalEmbedding,\
BucketPositionalEmbedding,\
LearnedPositionalEmbedding

To solve this, two options are to either move SinusoidalPositionalEmbedding, BucketPositionalEmbedding, and LearnedPositionalEmbedding out of layers.py into a new file and change the import in attention_cell, or to move AdapterModule into a new file. You can also come up with other solutions.
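A minimal sketch of the second option, assuming a hypothetical new module such as src/gluonnlp/adapters.py; the file name and the relocation are illustrative and not part of this PR:

# src/gluonnlp/adapters.py -- hypothetical new module that can import from both
# layers.py and attention_cell.py; since layers.py then no longer needs
# attention_cell, the import cycle disappears.
from mxnet.gluon import nn
from mxnet.util import use_np

from .layers import get_activation                    # already provided by layers.py
from .attention_cell import MultiHeadAttentionCell    # no longer circular here


@use_np
class AdapterModule(nn.HybridBlock):
    """AdapterModule (and its helpers) relocated out of layers.py."""
    ...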




InitializerType = Optional[Union[mx.init.Initializer, str]]
@@ -478,6 +480,8 @@ def forward(self, positions):
return np.take(self.weight.data(), positions, axis=0, mode=self._mode)




@use_np
class BucketPositionalEmbedding(HybridBlock):
"""Divide the positional space into buckets and assign the relative positions within each
@@ -543,6 +547,8 @@ def __init__(self,
layer_norm_eps: float = 1E-5,
pre_norm: bool = False,
dtype='float32',
use_adapter: bool = False,
adapter_config={},
**kwargs):
"""

@@ -570,6 +576,7 @@ def __init__(self,
self._dtype = dtype
self._pre_norm = pre_norm
self._use_gated_activation = use_gated_activation
self._use_adapter = use_adapter
self._kwargs = OrderedDict([
('units', units),
('hidden_size', hidden_size),
@@ -611,6 +618,8 @@ def __init__(self,
normalization=normalization,
epsilon=layer_norm_eps,
**kwargs)
if self._use_adapter:
self.adapter_layer_ffn = AdapterModule(in_units=units, adapter_config=adapter_config)

def forward(self, data):
"""
@@ -637,6 +646,8 @@ def forward(self, data):
out = self.activation_dropout_layer(out)
out = self.ffn_2(out)
out = self.dropout_layer(out)
if self._use_adapter:
out = self.adapter_layer_ffn(out)
out = out + residual
if not self._pre_norm:
out = self.layer_norm(out)
@@ -1007,3 +1018,77 @@ def forward(self, hidden, target):

def __repr__(self):
return _gen_repr_with_kwargs(self._kwargs, self.__class__.__name__)

@use_np
class AdapterModule(nn.HybridBlock):
def __init__(self, in_units:int, adapter_config:dict):
super().__init__()
self._adapter_config = adapter_config
self.base_adapter_stacks = nn.HybridSequential()
for name in adapter_config['task_names']:
self.base_adapter_stacks.add(get_adapter(adapter_config[name], in_units))
if adapter_config['adapter_fusion']:
self.adapter_fusion = AdapterFusion(adapter_config['adapter_fusion_config'], in_units)

def forward(self, data):
output = []
for base_adapter in self.base_adapter_stacks:
output.append(base_adapter(data))
if self._adapter_config['adapter_fusion']:
output = np.stack(output, axis=0)
output = self.adapter_fusion(output)
return output
else:
return output[0]





@use_np
def get_adapter(base_adapter_config, in_units):
if base_adapter_config['type'] == 'Basic':
return BasicAdapter(units=base_adapter_config['unit'], in_units=in_units)
else:
pass
# lxy: other adapter types are not implemented yet


@use_np
class AdapterFusion(nn.HybridBlock):
def __init__(self, config, in_units):
super().__init__()
self._config = config
self.query_proj = nn.Dense(in_units=in_units, units=in_units)
self.key_proj = nn.Dense(in_units=in_units, units=in_units)
self.value_proj = nn.Dense(in_units=in_units, units=in_units)
self.attention_cell = MultiHeadAttentionCell(query_units=in_units,
num_heads=1,
attention_dropout=0,
scaled=True)

def forward(self, query, key, value):
query = self.query_proj(query)
key = self.key_proj(key)
value = self.value_proj(value)
output = self.attention_cell(query, key, value)
return output

@use_np
class BasicAdapter(nn.HybridBlock):
def __init__(self, units: int, in_units: int):
super().__init__()
self._units = units
self.down_proj = nn.Dense(in_units=in_units,
units=units,
flatten=False)
self.activate = get_activation('gelu')
self.up_proj = nn.Dense(in_units=units,
units=in_units,
flatten=False)

def forward(self, data):
out = self.down_proj(data)
out = self.activate(out)
out = self.up_proj(out)
return out + data
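
For reference, a minimal usage sketch of the bottleneck adapter above; the shapes, the CPU context, and the standalone instantiation are illustrative assumptions rather than part of the PR:

import mxnet as mx

# Standalone check of BasicAdapter: project 768-d features to a 64-d bottleneck,
# apply GELU, project back up, and add the residual connection.
adapter = BasicAdapter(units=64, in_units=768)
adapter.initialize(ctx=mx.cpu())
x = mx.np.ones((2, 8, 768))   # (batch, sequence length, hidden size) -- illustrative
y = adapter(x)                # output keeps the input shape
assert y.shape == (2, 8, 768)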
