Skip to content

Commit bda930e

Browse files
committed
cli
1 parent a15431e commit bda930e

File tree

3 files changed

+27
-18
lines changed

3 files changed

+27
-18
lines changed

.travis.yml

Lines changed: 0 additions & 7 deletions
This file was deleted.

bert_score_cli/score.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def main():
1717
help='two-letter abbreviation of the language (e.g., en) or "en-sci" for scientific text',
1818
)
1919
parser.add_argument(
20-
"-m", "--model", default=None, help="BERT model name (default: bert-base-uncased) or path to a pretrain model"
20+
"-m", "--model", default=None, help="BERT model name (default: codebert-base) or path to a pretrain model"
2121
)
2222
parser.add_argument("-l", "--num_layers", type=int, default=None, help="use first N layer in BERT (default: 8)")
2323
parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size (default: 64)")
@@ -33,6 +33,12 @@ def main():
3333
parser.add_argument(
3434
"-c", "--cand", type=str, required=True, help="candidate (system outputs) file path or a string"
3535
)
36+
parser.add_argument("--no_punc", action="store_true", help="exclude punctuation-only tokens in candidate and reference")
37+
parser.add_argument(
38+
"--sources", type=str, required=True, help="a list of a source for each candidate, to be concatenated"
39+
"with the candidates but removed from the similarity computation"
40+
)
41+
parser.add_argument("--chunk_overlap", type=float, default=0.5, help="how much overlap between chunks, when the input is longer than the models' max length")
3642

3743
args = parser.parse_args()
3844

@@ -67,17 +73,21 @@ def main():
6773
return_hash=True,
6874
rescale_with_baseline=args.rescale_with_baseline,
6975
baseline_path=args.baseline_path,
76+
no_punc=args.no_punc,
77+
sources=args.sources,
78+
chunk_overlap=args.chunk_overlap,
7079
)
7180
avg_scores = [s.mean(dim=0) for s in all_preds]
7281
P = avg_scores[0].cpu().item()
7382
R = avg_scores[1].cpu().item()
7483
F1 = avg_scores[2].cpu().item()
75-
msg = hash_code + f" P: {P:.6f} R: {R:.6f} F1: {F1:.6f}"
84+
F3 = avg_scores[3].cpu().item()
85+
msg = hash_code + f" P: {P:.6f} R: {R:.6f} F1: {F1:.6f} F3: {F3:.6f}"
7686
print(msg)
7787
if args.seg_level:
78-
ps, rs, fs = all_preds
79-
for p, r, f in zip(ps, rs, fs):
80-
print("{:.6f}\t{:.6f}\t{:.6f}".format(p, r, f))
88+
ps, rs, fs, f3s = all_preds
89+
for p, r, f, f3 in zip(ps, rs, fs, f3s):
90+
print("{:.6f}\t{:.6f}\t{:.6f}\t{:.6f}".format(p, r, f, f3))
8191

8292

8393
if __name__ == "__main__":

example.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,16 @@ def print_results(predictions, refs, pred_results):
5656
with open('idf_dicts/java_idf.pkl', 'rb') as f:
5757
java_idf = pickle.load(f)
5858

59-
pred_results = code_bert_score.score([''],['a'], sources=["a"], lang="python")
60-
pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf)
61-
print_results(predictions, refs, pred_results)
59+
# pred_results = code_bert_score.score([''],['a'], sources=["a"], lang="python")
60+
# pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf)
61+
# print_results(predictions, refs, pred_results)
6262

63-
print('When providing the context: "find the index of target in this.elements"')
64-
pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf, sources=['find the index of target in this.elements'] * 2)
65-
print_results(predictions, refs, pred_results)
63+
# print('When providing the context: "find the index of target in this.elements"')
64+
# pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf, sources=['find the index of target in this.elements'] * 2)
65+
# print_results(predictions, refs, pred_results)
66+
67+
68+
with open('idf_dicts/python_idf.pkl', 'rb') as f:
69+
python_idf = pickle.load(f)
70+
pred_results = code_bert_score.score(cands=['math.sqrt(x)'], refs=[['x ** 0.5']], no_punc=True, lang='python', idf=python_idf)
71+
print(pred_results)

0 commit comments

Comments
 (0)