This repository was archived by the owner on Jan 15, 2024. It is now read-only.
34 changes: 33 additions & 1 deletion scripts/classification/README.md
@@ -51,4 +51,36 @@ here are some results with their hyperparameters
| CoLA | Matthew Corr. | 2e-5 | 32 | 7800 | 10 | 59.23 | https://tensorboard.dev/experiment/33euRGh9SrW3p15JWgILnw/ |
| RTE | Accuracy | 2e-5 | 32 | 1800 | 10 | 69.67 | https://tensorboard.dev/experiment/XjTxr5anRrC1LMukLJJQ3g/|
| MRPC | Accuracy/F1 | 3e-5 | 32 | 7800 | 5 | 85.38/87.31 | https://tensorboard.dev/experiment/jEJFq2XXQ8SvCxt6eKIjwg/ |
| MNLI | Accuracy(m/mm) | 2e-5 | 48 | 7800 | 5 | 84.90/85.10 | https://tensorboard.dev/experiment/CZQlOBedRQeTZwn5o5fbKQ/ |
| MNLI | Accuracy(m/mm) | 2e-5 | 48 | 7800 | 4 | 84.90/85.10 | https://tensorboard.dev/experiment/CZQlOBedRQeTZwn5o5fbKQ/ |


## Different fine-tuning methods
We also offer lighter-weight fine-tuning methods to save time and space. Two such methods are currently available: bias fine-tuning (`--method bias`) and adapter fine-tuning (`--method adapter`). To use them, add the `--method` argument, for example:
```bash
python train_classification.py \
--model_name google_en_uncased_bert_base \
--method adapter \
--task_name mrpc \
--lr 4.5e-4 \
--batch_size 32 \
--do_train \
--do_eval \
--seed 7800 \
--epochs 10 \
--optimizer adamw \
--train_dir glue/mrpc/train.parquet \
--eval_dir glue/mrpc/dev.parquet \
--gpus 1
```
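
The bias method is selected the same way. The sketch below assumes the same MRPC data layout as above; the learning rate shown is only illustrative, not a tuned value:

```bash
python train_classification.py \
--model_name google_en_uncased_bert_base \
--method bias \
--task_name mrpc \
--lr 1e-4 \
--batch_size 32 \
--do_train \
--do_eval \
--seed 7800 \
--epochs 10 \
--optimizer adamw \
--train_dir glue/mrpc/train.parquet \
--eval_dir glue/mrpc/dev.parquet \
--gpus 1
```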
Here are some results for the different methods (a blank cell means we have not yet found suitable hyperparameters):

| Task Name | Metric | full | bias-finetune | adapter |
|-----------|-------------|-------------|-------------|-------------|
| SST | Accuracy | 93.23 | | 93.46 |
| STS | Pearson Corr. | 89.26 | 89.30 | 89.70 |
| CoLA | Matthew Corr. | 59.23 | | 61.20 |
| RTE | Accuracy | 69.67 | 69.31 | 70.75 |
| MRPC | Accuracy/F1 | 85.38/87.31 | 85.29/88.63 | 87.74/91.39|
| MNLI | Accuracy(m/mm) | 84.90/85.10 | | |
91 changes: 77 additions & 14 deletions scripts/classification/train_classification.py
@@ -5,7 +5,9 @@
import json
import random
import pandas as pd
import mxnet.numpy_extension as _mx_npx
import os
import logging
import time
import argparse
@@ -92,13 +94,36 @@ def parse_args():
help='the path to training dataset')
parser.add_argument('--warmup_ratio', type=float, default=0.1,
help='Ratio of warmup steps in the learning rate scheduler.')
parser.add_argument('--method', type=str, default='full', choices=['full', 'bias', 'adapter', 'last_layer'],
help='Fine-tuning method: full fine-tuning, bias-only, adapter, or last layer only.')


args = parser.parse_args()
return args


def change_adapter_cfg(cfg, task):
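# Build the adapter configuration for this task: two insertion locations, each holding a Basic adapter with 64 hidden units and GELU activation.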
adapter_config = {
'location_0':{
'adapter_fusion':False,
'pre_operator':False,
'task_names':[task.task_name],
task.task_name:{'type':'Basic','units':64, 'activation':'gelu'}},
'location_1':{
'adapter_fusion':False,
'pre_operator':False,
'task_names':[task.task_name],
task.task_name:{'type':'Basic','units':64, 'activation':'gelu'}}
}
cfg.defrost()
cfg.MODEL.use_adapter = True
cfg.MODEL.adapter_config = json.dumps(adapter_config)
cfg.freeze()
return cfg

def get_network(model_name,
ctx_l,
method='full',
checkpoint_path=None,
backbone_path=None,
task=None):
@@ -109,13 +134,15 @@ def get_network(model_name,
use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
Model, cfg, tokenizer, download_params_path, _ = \
get_backbone(model_name, load_backbone=not backbone_path)

if method == 'adapter':
cfg = change_adapter_cfg(cfg, task)
backbone = Model.from_cfg(cfg)
# Load local backbone parameters if backbone_path provided.
# Otherwise, download backbone parameters from gluon zoo.

backbone_params_path = backbone_path if backbone_path else download_params_path
if checkpoint_path is None:
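# allow_missing lets newly introduced parameters (e.g. adapters) that are absent from the pretrained checkpoint keep their fresh initialization.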
backbone.load_parameters(backbone_params_path, ignore_extra=True,
backbone.load_parameters(backbone_params_path, ignore_extra=True, allow_missing=(method != 'full'),
ctx=ctx_l, cast_dtype=True)
num_params, num_fixed_params \
= count_parameters(deduplicate_param_dict(backbone.collect_params()))
@@ -219,20 +246,23 @@ def train(args):
#random seed
set_seed(args.seed)
level = logging.INFO
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
detail_dir = os.path.join(args.output_dir, args.task_name)
if not os.path.exists(detail_dir):
os.mkdir(detail_dir)
logging_config(detail_dir,
name='train_{}_{}_'.format(args.task_name, args.model_name) + str(rank), # avoid race
name='train_{}_{}_{}_'.format(args.task_name, args.model_name, args.method) + str(rank), # avoid race
level=level,
console=(local_rank == 0))
logging.info(args)
cfg, tokenizer, classify_net, use_segmentation = \
get_network(args.model_name, ctx_l,
get_network(args.model_name, ctx_l, args.method,
args.param_checkpoint,
args.backbone_path,
task)


logging.info('Prepare training data')
train_data, _ = get_task_data(args, task, tokenizer, segment='train')
train_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()),
@@ -253,6 +283,26 @@
sampler=sampler)


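# Select the parameters to fine-tune according to --method; everything else is frozen below.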
if args.method == 'full':
target_params_name = classify_net.collect_params().keys()
elif args.method == 'bias':
target_params_name = [key
for key in classify_net.collect_params() if
key.endswith('bias') or key.endswith('beta') or 'out_proj' in key]
elif args.method == 'adapter':
target_params_name = [key
for key in classify_net.collect_params() if
'adapter' in key or 'out_proj' in key]
elif args.method == 'last_layer':
target_params_name = [key
for key in classify_net.collect_params() if
'out_proj' in key]
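# Freeze all parameters that were not selected above by disabling their gradients.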
for name in classify_net.collect_params():
if name not in target_params_name:
classify_net.collect_params()[name].grad_req = 'null'

target_params = {name:classify_net.collect_params()[name] for name in target_params_name}


param_dict = classify_net.collect_params()
# Do not apply weight decay to all the LayerNorm and bias
@@ -269,7 +319,7 @@
if local_rank == 0:
writer = SummaryWriter(logdir=os.path.join(args.output_dir,
args.task_name + '_tensorboard_' +
str(args.lr) + '_' + str(args.epochs)))
str(args.lr) + '_' + str(args.epochs) + '_' + str(args.method)))
if args.comm_backend == 'horovod':
# Horovod: fetch and broadcast parameters
hvd.broadcast_parameters(param_dict, root_rank=0)
@@ -290,10 +340,12 @@
optimizer_params = {'learning_rate': args.lr,
'wd': args.wd,
'lr_scheduler': lr_scheduler}


if args.comm_backend == 'horovod':
trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
trainer = hvd.DistributedTrainer(target_params, args.optimizer, optimizer_params)
else:
trainer = mx.gluon.Trainer(classify_net.collect_params(),
trainer = mx.gluon.Trainer(target_params,
'adamw',
optimizer_params)

@@ -376,16 +428,22 @@
log_gnorm = 0
log_step = 0
if local_rank == 0 and (i == max_update - 1 or i%(max_update//args.epochs) == 0 and i>0):
ckpt_name = '{}_{}_{}.params'.format(args.model_name,
args.task_name,
(i + 1))
ckpt_name = '{}_{}_{}_{}.params'.format(args.model_name,
args.task_name,
(i + 1),
args.method)

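# Save only the parameters that were actually fine-tuned; for the bias/adapter/last_layer methods this keeps checkpoints small.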
tmp_params = classify_net._collect_params_with_prefix()
params_saved = os.path.join(detail_dir, ckpt_name)
classify_net.save_parameters(params_saved)
arg_dict = {key: tmp_params[key]._reduce() for key in target_params}
_mx_npx.savez(params_saved, **arg_dict)
logging.info('Params saved in: {}'.format(params_saved))
for metric in metrics:
metric.reset()

end_time = time.time()
logging.info('Total costs:{}'.format(end_time - start_time))



def evaluate(args):
@@ -410,19 +468,24 @@
str(ctx_l)))

cfg, tokenizer, classify_net, use_segmentation = \
get_network(args.model_name, ctx_l,
get_network(args.model_name, ctx_l, args.method,
args.param_checkpoint,
args.backbone_path,
task)

candidate_ckpt = []
detail_dir = os.path.join(args.output_dir, args.task_name)
for name in os.listdir(detail_dir):
if name.endswith('.params') and args.task_name in name and args.model_name in name:
if name.endswith(args.method + '.params') and args.task_name in name and args.model_name in name:
candidate_ckpt.append(os.path.join(detail_dir, name))
candidate_ckpt.sort(reverse=False)
best_ckpt = {}
metrics = task.metric
def evaluate_by_ckpt(ckpt_name, best_ckpt):
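# Load the saved (possibly partial) parameter dict on top of the pretrained backbone; allow_missing/ignore_extra cover methods that checkpoint only a subset of parameters.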
classify_net.load_parameters(ckpt_name, ctx=ctx_l, cast_dtype=True)
loaded = _mx_npx.load(ckpt_name)
full_dict = {'params': loaded, 'filename': ckpt_name}
classify_net.load_dict(full_dict, ctx_l, allow_missing=True,
ignore_extra=True, cast_dtype=True)
logging.info('Prepare dev data')

dev_data, label = get_task_data(args, task, tokenizer, segment='eval')