Skip to content

Commit 028f9ac

Browse files
author
tianxin
authored
Support multi-card evaluation (#410)
* add SemanticIndexing examples * 1. delete shell script 2. refactor sub-directory * 1. use data.py instead of utils.py 2. calculate cosine_sim on GPU not CPU * 1. delete python api in Model forward function for dynamic to static conversion 2. add margin parameter * use data_set for predict.py * use Dataset to load evaluation data when run_ann.py * 1. set some command argument to required 2. rename some variables * add BatchNegative strategy * delete .gitkeep * add ance strategy * delete unused arguments * delete .gitkeep * set data.py as common module of strategies * 1. add SemanticIndexBase 2. move common functions to data.py * mv build_index to ann_util.py * 1. mv train.py to top directory * get emb_size from model parameters * support reducing embedding size via command-line argument * handle illegal data * add ClipGradByGlobalNorm for train_ance.py * add README.md * update README * Update README * Update README * Update README * upload model and data to bos * update README * update comment * delete unused blanks * mend * add parameter description to README * update README * delete numpy in model.py * mend * mend * 1. support multi-card evaluation 2. fix encoding for windows
1 parent 97b52bc commit 028f9ac

File tree

4 files changed

+19
-13
lines changed

4 files changed

+19
-13
lines changed

examples/semantic_indexing/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
以下模型结构参数为:
1919
`TransformerLayer:12, Hidden:768, Heads:12, OutputEmbSize: 256`
2020

21-
|Model|训练参数配置|MD5|
22-
| ------------ | ------------ | ------------ |
23-
|[batch_neg_v1.0](https://paddlenlp.bj.bcebos.com/models/semantic_index/batch_neg_v1.0.tar)|<div style="width: 200pt">margin:0.2 scale:30 epoch:3 lr:5E-5</div>|da1bb1487bd3fd6a53b8ef95c278f3e6|
24-
|[hardest_neg_v1.0](https://paddlenlp.bj.bcebos.com/models/semantic_index/hardest_neg_v1.0.tar)|margin:0.2 epoch:3 lr:5E-5|b535d890110ea608c8562c525a0b84b5|
21+
|Model|训练参数配置|硬件|MD5|
22+
| ------------ | ------------ | ------------ |-----------|
23+
|[batch_neg_v1.0](https://paddlenlp.bj.bcebos.com/models/semantic_index/batch_neg_v1.0.tar)|<div style="width: 150pt">margin:0.2 scale:30 epoch:3 lr:5E-5 bs:128 max_len:64 </div>|<div style="width: 100pt">单卡v100-16g</div>|da1bb1487bd3fd6a53b8ef95c278f3e6|
24+
|[hardest_neg_v1.0](https://paddlenlp.bj.bcebos.com/models/semantic_index/hardest_neg_v1.0.tar)|margin:0.2 epoch:3 lr:5E-5 bs:128 max_len:64 |单卡v100-16g|b535d890110ea608c8562c525a0b84b5|
2525

2626

2727
## 数据准备

examples/semantic_indexing/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def valid_checkpoint(step):
155155

156156
def gen_id2corpus(corpus_file):
157157
id2corpus = {}
158-
with open(corpus_file) as f:
158+
with open(corpus_file, 'r', encoding='utf-8') as f:
159159
for idx, line in enumerate(f):
160160
id2corpus[idx] = line.rstrip()
161161
return id2corpus
@@ -164,7 +164,7 @@ def gen_id2corpus(corpus_file):
164164
def gen_text_file(similar_text_pair_file):
165165
text2similar_text = {}
166166
texts = []
167-
with open(similar_text_pair_file) as f:
167+
with open(similar_text_pair_file, 'r', encoding='utf-8') as f:
168168
for line in f:
169169
splited_line = line.rstrip().split("\t")
170170
if len(splited_line) != 2:

examples/semantic_indexing/evaluate.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ def recall(rs, N=10):
5252

5353
if __name__ == "__main__":
5454
text2similar = {}
55-
with open(args.similar_text_pair) as f:
55+
with open(args.similar_text_pair, 'r', encoding='utf-8') as f:
5656
for line in f:
5757
text, similar_text = line.rstrip().split("\t")
5858
text2similar[text] = similar_text
5959

6060
rs = []
6161

62-
with open(args.recall_result_file) as f:
62+
with open(args.recall_result_file, 'r', encoding='utf-8') as f:
6363
relevance_labels = []
6464
for index, line in enumerate(f):
6565

@@ -77,7 +77,6 @@ def recall(rs, N=10):
7777

7878
recall_N = []
7979
for topN in (10, 50):
80-
#logger.info("Recall@{}: {}".format(topN, 100 * recall(rs, N=topN)))
8180
R = round(100 * recall(rs, N=topN), 3)
8281
recall_N.append(str(R))
8382
print("\t".join(recall_N))

examples/semantic_indexing/recall.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@
5858

5959
if __name__ == "__main__":
6060
paddle.set_device(args.device)
61+
rank = paddle.distributed.get_rank()
62+
if paddle.distributed.get_world_size() > 1:
63+
paddle.distributed.init_parallel_env()
6164

6265
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
6366

@@ -76,8 +79,9 @@
7679

7780
model = SemanticIndexBase(
7881
pretrained_model, output_emb_size=args.output_emb_size)
82+
model = paddle.DataParallel(model)
7983

80-
# load pretrained semantic model
84+
# Load pretrained semantic model
8185
if args.params_path and os.path.isfile(args.params_path):
8286
state_dict = paddle.load(args.params_path)
8387
model.set_dict(state_dict)
@@ -99,7 +103,10 @@
99103
batchify_fn=batchify_fn,
100104
trans_fn=trans_func)
101105

102-
final_index = build_index(args, corpus_data_loader, model)
106+
# Need better way to get inner model of DataParallel
107+
inner_model = model._layers
108+
109+
final_index = build_index(args, corpus_data_loader, inner_model)
103110

104111
text_list, text2similar_text = gen_text_file(args.similar_text_pair_file)
105112

@@ -112,14 +119,14 @@
112119
batchify_fn=batchify_fn,
113120
trans_fn=trans_func)
114121

115-
query_embedding = model.get_semantic_embedding(query_data_loader)
122+
query_embedding = inner_model.get_semantic_embedding(query_data_loader)
116123

117124
if not os.path.exists(args.recall_result_dir):
118125
os.mkdir(args.recall_result_dir)
119126

120127
recall_result_file = os.path.join(args.recall_result_dir,
121128
args.recall_result_file)
122-
with open(recall_result_file, 'w') as f:
129+
with open(recall_result_file, 'w', encoding='utf-8') as f:
123130
for batch_index, batch_query_embedding in enumerate(query_embedding):
124131
recalled_idx, cosine_sims = final_index.knn_query(
125132
batch_query_embedding.numpy(), args.recall_num)

0 commit comments

Comments
 (0)