Commit d6844c5 ("Add code")
Author: Aston Zhang
Parent: fee8c35
File tree

11 files changed: +1365 −7 lines changed
README.md

Lines changed: 68 additions & 7 deletions
(This commit replaces the previous repository-template README: the "My Project" placeholder title, the "TODO: Fill this README out!" checklist, and the Security note pointing to CONTRIBUTING.md.)

# Multimodal Chain-of-Thought Reasoning in Language Models

Multimodal-CoT incorporates vision features in a decoupled training framework. The framework consists of two training stages: (i) rationale generation and (ii) answer inference. Both stages share the same model architecture but differ in their inputs and outputs.

![](vision_features/mm-cot.png)
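The sketch below only illustrates the decoupled two-stage flow described above; the stub functions are placeholders, not the actual T5-based models trained by main.py.

```python
# Illustrative sketch of the two-stage framework; the stubs stand in for the
# real models. Stage (i) produces a rationale from the multimodal input, and
# stage (ii) appends that rationale to the input before predicting the answer.
def generate_rationale(question, context, options, vision_features):
    # Stage (i): language input + vision features -> rationale (stubbed).
    return f"Rationale: reason step by step about '{question}'."

def infer_answer(question, context, options, rationale, vision_features):
    # Stage (ii): the generated rationale becomes part of the input (stubbed).
    return options[0]

question = "Which property do these objects have in common?"
context = "Look at each object."
options = ["soft", "hard"]
vision_features = None  # placeholder for extracted image features (e.g., DETR)

rationale = generate_rationale(question, context, options, vision_features)
answer = infer_answer(question, context, options, rationale, vision_features)
print(rationale)
print("Answer:", answer)
```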
## Requirements

Install all required Python dependencies:

```
pip install -r requirements.txt
```
## Datasets

Download the datasets from the following:

```
https://github.com/lupantech/ScienceQA/tree/main/data
```

Download the extracted vision features from [Anonymous](xxx) and unzip the files under `vision_features`.
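A quick sanity check after downloading; `vision_features` is the folder named above, while placing the ScienceQA download under `data` is only an assumption, so adjust the names to match your layout.

```python
# Check that the downloads are in place. "vision_features" comes from this README;
# "data" as the location of the ScienceQA download is an assumption.
from pathlib import Path

for folder in ("data", "vision_features"):
    print(f"{folder}: {'found' if Path(folder).is_dir() else 'missing'}")
```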
## Instructions

### Training

```
# rationale generation
CUDA_VISIBLE_DEVICES=0,1 python main.py \
    --model allenai/unifiedqa-t5-base \
    --user_msg rationale --img_type detr \
    --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
    --final_eval --prompt_format QCM-LE

# answer inference
CUDA_VISIBLE_DEVICES=0,1 python main.py \
    --model allenai/unifiedqa-t5-base \
    --user_msg answer --img_type detr \
    --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
    --final_eval --prompt_format QCMG-A \
    --eval_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_eval.json \
    --test_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_test.json
```
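If you prefer to launch both stages from one script, a small wrapper like the sketch below can run them in sequence. It is not part of this repo; the `RATIONALE_DIR` path is copied from the command above and encodes that run's hyperparameters, so adjust it if your run writes elsewhere.

```python
# Optional convenience wrapper (assumption, not provided by the repo): run stage (i)
# and then stage (ii) via subprocess, wiring stage-(i) predictions into stage (ii).
import os
import subprocess

RATIONALE_DIR = "experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20"
env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0,1"}

rationale_cmd = [
    "python", "main.py",
    "--model", "allenai/unifiedqa-t5-base",
    "--user_msg", "rationale", "--img_type", "detr",
    "--bs", "8", "--eval_bs", "4", "--eval_acc", "10", "--output_len", "512",
    "--final_eval", "--prompt_format", "QCM-LE",
]
answer_cmd = [
    "python", "main.py",
    "--model", "allenai/unifiedqa-t5-base",
    "--user_msg", "answer", "--img_type", "detr",
    "--bs", "8", "--eval_bs", "4", "--eval_acc", "10", "--output_len", "64",
    "--final_eval", "--prompt_format", "QCMG-A",
    "--eval_le", f"{RATIONALE_DIR}/predictions_ans_eval.json",
    "--test_le", f"{RATIONALE_DIR}/predictions_ans_test.json",
]

subprocess.run(rationale_cmd, check=True, env=env)  # stage (i): rationale generation
subprocess.run(answer_cmd, check=True, env=env)     # stage (ii): answer inference
```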
### Inference

Our trained models are available at [Anonymous](xxx). To use our trained models, please put them under the `models` folder.

```
# rationale generation
CUDA_VISIBLE_DEVICES=0,1 python main.py \
    --model allenai/unifiedqa-t5-base \
    --user_msg rationale --img_type detr \
    --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
    --final_eval --prompt_format QCM-LE \
    --evaluate_dir models/rationale

# answer inference
CUDA_VISIBLE_DEVICES=0,1 python main.py \
    --model allenai/unifiedqa-t5-base \
    --user_msg answer --img_type detr \
    --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
    --final_eval --prompt_format QCMG-A \
    --eval_le models/rationale/predictions_ans_eval.json \
    --test_le models/rationale/predictions_ans_test.json \
    --evaluate_dir models/answer
```
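Before running the answer-inference command, you can peek at a stage-(i) prediction file that feeds `--eval_le`/`--test_le`. Only the file path below is taken from the commands above; the snippet makes no assumption about the file's internal structure and just reports its top-level shape.

```python
# Inspect a stage-(i) prediction file before passing it to --eval_le / --test_le.
import json

with open("models/rationale/predictions_ans_eval.json") as f:
    preds = json.load(f)

print(type(preds).__name__)  # dict or list, depending on what the rationale run wrote
print(list(preds)[:5] if isinstance(preds, (dict, list)) else preds)
```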
## License

This project is licensed under the Apache-2.0 License.

## Acknowledgement

Part of our code is adapted from [ScienceQA](https://github.com/lupantech/ScienceQA) and [Transformers](https://github.com/huggingface/transformers).

evaluations.py

Lines changed: 100 additions & 0 deletions (new file)
```python
'''
Adapted from https://github.com/lupantech/ScienceQA
'''

import re
from rouge import Rouge
from nltk.translate.bleu_score import sentence_bleu
from sentence_transformers import util


########################
## BLEU
########################
def tokenize(text):
    # Split on whitespace and periods, dropping empty tokens.
    tokens = re.split(r'\s|\.', text)
    tokens = [t for t in tokens if len(t) > 0]
    return tokens


def bleu_score(reference, hypothesis, gram):
    reference_tokens = tokenize(reference)
    hypothesis_tokens = tokenize(hypothesis)

    if gram == 1:
        bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1.,))  # BLEU-1
    elif gram == 2:
        bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 2., 1. / 2.))  # BLEU-2
    elif gram == 3:
        bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 3., 1. / 3., 1. / 3.))  # BLEU-3
    elif gram == 4:
        bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 4., 1. / 4., 1. / 4., 1. / 4.))  # BLEU-4
    else:
        raise ValueError("gram must be 1, 2, 3, or 4")

    return bleu


def caculate_bleu(results, data, gram):
    # Average BLEU-n over all predictions; `results` and `data` map qid -> text.
    bleus = []
    for qid, output in results.items():
        prediction = output
        target = data[qid]
        target = target.strip()
        if target == "":
            continue
        bleu = bleu_score(target, prediction, gram)
        bleus.append(bleu)

    avg_bleu = sum(bleus) / len(bleus)

    return avg_bleu


########################
## Rouge-L
########################
def score_rouge(str1, str2):
    rouge = Rouge(metrics=["rouge-l"])
    scores = rouge.get_scores(str1, str2, avg=True)
    rouge_l = scores['rouge-l']['f']
    return rouge_l


def caculate_rouge(results, data):
    # Average ROUGE-L F1 over all non-empty prediction/target pairs.
    rouges = []
    for qid, output in results.items():
        prediction = output
        target = data[qid]
        target = target.strip()
        if prediction == "":
            continue
        if target == "":
            continue
        rouge = score_rouge(target, prediction)
        rouges.append(rouge)

    avg_rouge = sum(rouges) / len(rouges)
    return avg_rouge


########################
## Sentence Similarity
########################
def similariry_score(str1, str2, model):
    # Compute an embedding for each string and return their cosine similarity.
    embedding_1 = model.encode(str1, convert_to_tensor=True)
    embedding_2 = model.encode(str2, convert_to_tensor=True)
    score = util.pytorch_cos_sim(embedding_1, embedding_2).item()
    return score


def caculate_similariry(results, data, model):
    # Average cosine similarity between each prediction and its target.
    scores = []
    for qid, output in results.items():
        prediction = output
        target = data[qid]
        target = target.strip()

        score = similariry_score(target, prediction, model)
        scores.append(score)

    avg_score = sum(scores) / len(scores)
    return avg_score
```
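A quick way to exercise these metrics on toy data is sketched below. The prediction/target strings and the SentenceTransformer checkpoint name are only examples, not values fixed by this repo.

```python
# Toy usage of evaluations.py (illustrative; the example data and the embedding
# model name below are assumptions).
from sentence_transformers import SentenceTransformer
from evaluations import caculate_bleu, caculate_rouge, caculate_similariry

# qid -> generated rationale, and qid -> reference rationale
results = {"q1": "The magnet attracts the iron filings because they are magnetic."}
targets = {"q1": "Iron filings are attracted because iron is a magnetic material."}

print("BLEU-1 :", caculate_bleu(results, targets, gram=1))
print("BLEU-4 :", caculate_bleu(results, targets, gram=4))
print("ROUGE-L:", caculate_rouge(results, targets))

# Any sentence-embedding checkpoint works here; this one is just an example.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
print("Similarity:", caculate_similariry(results, targets, model))
```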
