17 changes: 17 additions & 0 deletions scripts/train/cspider_text2sql/train_model_multi_gpu.sh
@@ -0,0 +1,17 @@
# train the text2sql mt5-large model on the medical CSpider data (NatSQL target, multi-GPU)
python -m torch.distributed.launch --nproc_per_node=3 --nnodes=1 text2sql.py \
--batch_size 4 \
--gradient_descent_step 4 \
--device "1,2,3" \
--learning_rate 5e-5 \
--epochs 512 \
--seed 42 \
--save_path "./models/myTrain/text2sql-mt5-large-medical-cspider-multigpu-better-data" \
--tensorboard_save_path "./tensorboard_log/text2sql-mt5-large-medical-spider-multigpu-better-data" \
--model_name_or_path "./models/mt5-large-raw" \
--dev_filepath "./data/Medical/preprocessed_data/train/resdsql_dev_medical_cspider_natsql.json" \
--use_adafactor \
--mode "train" \
--db_path "./data/Medical/database" \
--train_filepath "./data/Medical/preprocessed_data/train/resdsql_train_medical_cspider_natsql.json" \
--original_dev_filepath "./data/Medical/dev_medical_cspider.json"
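A note on the launcher: `torch.distributed.launch` spawns one process per GPU and passes `--local_rank` to each of them, so with `--nproc_per_node=3`, `--batch_size 4` (per process), and `--gradient_descent_step 4`, the effective global batch size is 3 × 4 × 4 = 48. Newer PyTorch releases deprecate this launcher in favor of `torchrun`, which sets the `LOCAL_RANK` environment variable instead of passing the flag. A minimal sketch that tolerates both conventions (hypothetical helper, not part of this PR):

```python
import os
import argparse

def get_local_rank() -> int:
    # torch.distributed.launch passes --local_rank on the command line;
    # torchrun exports LOCAL_RANK into the environment instead.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    args, _ = parser.parse_known_args()
    if args.local_rank >= 0:
        return args.local_rank
    return int(os.environ.get("LOCAL_RANK", 0))
```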
115 changes: 74 additions & 41 deletions text2sql.py
@@ -1,5 +1,6 @@
import os
import json
import random
import torch
import argparse
import torch.optim as optim
@@ -16,6 +17,9 @@
from utils.spider_metric.evaluator import EvaluateTool
from utils.load_dataset import Text2SQLDataset
from utils.text2sql_decoding_utils import decode_sqls, decode_natsqls
# changes for distributed training
from torch.utils.data.distributed import DistributedSampler
import numpy as np

def parse_option():
parser = argparse.ArgumentParser("command line arguments for fine-tuning pre-trained language model.")
@@ -67,6 +71,8 @@ def parse_option():
help = "sql or natsql.")
parser.add_argument("--output", type = str, default = "predicted_sql.txt",
help = "save file of the predicted sqls.")
# changes for distributed training
parser.add_argument("--local_rank", type=int, default=-1)

opt = parser.parse_args()

@@ -76,12 +82,29 @@ def _train(opt):
set_seed(opt.seed)
print(opt)

if opt.tensorboard_save_path is not None:
writer = SummaryWriter(opt.tensorboard_save_path)
else:
writer = None
os.environ["CUDA_VISIBLE_DEVICES"] = opt.device

# each process selects its GPU according to local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device('cuda', opt.local_rank)

# initialize the distributed process group
torch.distributed.init_process_group(backend='nccl')

# fix the random seeds
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)
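Seeding every rank with the same value keeps model initialization identical before DDP broadcasts the weights; data diversity across ranks comes from the `DistributedSampler`, not the seed. If per-rank randomness is ever needed (e.g. for augmentation in dataloader workers), a common pattern is to offset only the data-side seeds by rank. A sketch under that assumption, not what this PR does:

```python
import random
import numpy as np
import torch

rank = torch.distributed.get_rank()   # process index within the job
torch.manual_seed(opt.seed)           # identical model init on every rank
random.seed(opt.seed + rank)          # rank-specific Python-level randomness
np.random.seed(opt.seed + rank)       # rank-specific NumPy randomness
```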

writer = None
if opt.local_rank == 0 and opt.tensorboard_save_path is not None:
writer = SummaryWriter(opt.tensorboard_save_path)


os.environ["CUDA_VISIBLE_DEVICES"] = opt.device
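One ordering caveat: `CUDA_VISIBLE_DEVICES` only takes effect if it is set before the first CUDA call, and by this point `torch.cuda.set_device` has already initialized a context, so the assignment above is most likely a no-op. A sketch of the safer ordering (an assumption about intent, not what the PR currently does):

```python
# must come before any torch.cuda.* call for the remapping to be honored
os.environ["CUDA_VISIBLE_DEVICES"] = opt.device

torch.cuda.set_device(opt.local_rank)          # local_rank indexes into the visible set
device = torch.device("cuda", opt.local_rank)
torch.distributed.init_process_group(backend="nccl")
```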

text2sql_tokenizer = T5TokenizerFast.from_pretrained(
opt.model_name_or_path,
@@ -96,13 +119,15 @@ def _train(opt):
mode = "train"
)

train_dataloder = DataLoader(
train_dataset,
batch_size = opt.batch_size,
shuffle = True,
collate_fn = lambda x: x,
drop_last = True
)

# wrap the DataLoader with a DistributedSampler
train_sampler = DistributedSampler(train_dataset)
train_dataloder = torch.utils.data.DataLoader(train_dataset,
batch_size=opt.batch_size,
sampler=train_sampler,
collate_fn=lambda x: x,
drop_last=True)
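One caveat with `DistributedSampler`: its shuffle order is derived from the epoch number, so unless `set_epoch` is called at the top of each epoch, every epoch replays the same data order. A minimal sketch of the usual fix:

```python
for epoch in range(opt.epochs):
    # re-seed the sampler so each epoch gets a different shuffle across ranks
    train_sampler.set_epoch(epoch)
    for batch in train_dataloder:
        ...
```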


model_class = MT5ForConditionalGeneration if "mt5" in opt.model_name_or_path else T5ForConditionalGeneration

@@ -111,7 +136,13 @@ def _train(opt):
model = model_class.from_pretrained(opt.model_name_or_path)
model.resize_token_embeddings(len(text2sql_tokenizer))
if torch.cuda.is_available():
model = model.cuda()
model = model.to(device)
# wrap the model in DistributedDataParallel
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[opt.local_rank],
output_device=opt.local_rank,
find_unused_parameters=True)
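`find_unused_parameters=True` makes DDP walk the autograd graph after every backward pass, which adds measurable overhead; a standard mT5/T5 forward pass normally uses all parameters, so `False` is probably sufficient here. A sketch of the cheaper configuration (an untested assumption, worth verifying on one training step):

```python
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[opt.local_rank],
    output_device=opt.local_rank,
    # flip back to True only if DDP raises an unused-parameter error
    find_unused_parameters=False,
)
```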


print("finished.")

@@ -125,22 +156,22 @@ def _train(opt):
if opt.use_adafactor:
print("Let's use Adafactor!")
optimizer = Adafactor(
model.parameters(),
lr=opt.learning_rate,
scale_parameter=False,
relative_step=False,
model.parameters(),
lr=opt.learning_rate,
scale_parameter=False,
relative_step=False,
clip_threshold = 1.0,
warmup_init=False
)
else:
print("Let's use AdamW!")
optimizer = optim.AdamW(
model.parameters(),
model.parameters(),
lr = opt.learning_rate
)

scheduler = transformers.get_cosine_schedule_with_warmup(
optimizer,
optimizer,
num_warmup_steps = num_warmup_steps,
num_training_steps = num_training_steps
)
@@ -151,35 +182,35 @@ def _train(opt):
print(f"This is epoch {epoch+1}.")
for batch in train_dataloder:
train_step += 1

batch_inputs = [data[0] for data in batch]
batch_sqls = [data[1] for data in batch]
batch_db_ids = [data[2] for data in batch] # unused
batch_tc_original = [data[3] for data in batch] # unused

if epoch == 0:
for batch_id in range(len(batch_inputs)):
print(batch_inputs[batch_id])
print(batch_sqls[batch_id])
print("----------------------")

tokenized_inputs = text2sql_tokenizer(
batch_inputs,
batch_inputs,
padding = "max_length",
return_tensors = "pt",
max_length = 512,
truncation = True
)

with text2sql_tokenizer.as_target_tokenizer():
tokenized_outputs = text2sql_tokenizer(
batch_sqls,
padding = "max_length",
batch_sqls,
padding = "max_length",
return_tensors = 'pt',
max_length = 256,
truncation = True
)
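For what it's worth, recent `transformers` releases deprecate `as_target_tokenizer()` in favor of passing `text_target` directly; an equivalent call, assuming a sufficiently new library version:

```python
tokenized_outputs = text2sql_tokenizer(
    text_target=batch_sqls,
    padding="max_length",
    return_tensors="pt",
    max_length=256,
    truncation=True,
)
```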

encoder_input_ids = tokenized_inputs["input_ids"]
encoder_input_attention_mask = tokenized_inputs["attention_mask"]

@@ -188,19 +219,19 @@ def _train(opt):
decoder_attention_mask = tokenized_outputs["attention_mask"]

if torch.cuda.is_available():
encoder_input_ids = encoder_input_ids.cuda()
encoder_input_attention_mask = encoder_input_attention_mask.cuda()
decoder_labels = decoder_labels.cuda()
decoder_attention_mask = decoder_attention_mask.cuda()
encoder_input_ids = encoder_input_ids.to(device)
encoder_input_attention_mask = encoder_input_attention_mask.to(device)
decoder_labels = decoder_labels.to(device)
decoder_attention_mask = decoder_attention_mask.to(device)

model_outputs = model(
input_ids = encoder_input_ids,
attention_mask = encoder_input_attention_mask,
labels = decoder_labels,
decoder_attention_mask = decoder_attention_mask,
return_dict = True
)

loss = model_outputs["loss"]
loss.backward()

@@ -217,11 +248,13 @@ def _train(opt):
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
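Under DDP, every `loss.backward()` triggers a gradient all-reduce, even on micro-batches that only accumulate; wrapping the non-update steps in `model.no_sync()` defers communication to the final micro-batch of each accumulation window. A sketch, assuming `opt.gradient_descent_step` gates the optimizer step as the flag name suggests:

```python
from contextlib import nullcontext

# skip the all-reduce on micro-batches that do not end an accumulation window
is_update_step = (train_step % opt.gradient_descent_step == 0)
with model.no_sync() if not is_update_step else nullcontext():
    loss.backward()
```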

if train_step % num_checkpoint_steps == 0 and epoch >= 6:
print(f"At {train_step} training step, save a checkpoint.")
os.makedirs(opt.save_path, exist_ok = True)
model.save_pretrained(save_directory = opt.save_path + "/checkpoint-{}".format(train_step))
if torch.distributed.get_rank() == 0:
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(save_directory = opt.save_path + "/checkpoint-{}".format(train_step))
text2sql_tokenizer.save_pretrained(save_directory = opt.save_path + "/checkpoint-{}".format(train_step))
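Since only rank 0 writes checkpoints, adding a barrier afterwards keeps the other ranks from racing ahead while files are still being written; optional, but a common safeguard (a sketch, not in this PR; `checkpoint_dir` is a hypothetical local name):

```python
import os
import torch

if torch.distributed.get_rank() == 0:
    checkpoint_dir = os.path.join(opt.save_path, "checkpoint-{}".format(train_step))
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(save_directory=checkpoint_dir)
    text2sql_tokenizer.save_pretrained(save_directory=checkpoint_dir)
torch.distributed.barrier()  # everyone waits until the checkpoint is fully on disk
```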

def _test(opt):
@@ -311,12 +344,12 @@ def _test(opt):
)
elif opt.target_type == "natsql":
predict_sqls += decode_natsqls(
opt.db_path,
model_outputs,
batch_db_ids,
batch_inputs,
tokenizer,
batch_tc_original,
opt.db_path,
model_outputs,
batch_db_ids,
batch_inputs,
tokenizer,
batch_tc_original,
table_dict
)
else:
@@ -341,7 +374,7 @@ def _test(opt):
spider_metric_result = evaluator.evaluate(predict_sqls)
print('exact_match score: {}'.format(spider_metric_result["exact_match"]))
print('exec score: {}'.format(spider_metric_result["exec"]))

return spider_metric_result["exact_match"], spider_metric_result["exec"]

if __name__ == "__main__":