Skip to content

Commit c992e88

Browse files
author
hanjian.thu123
committed
[bugfix] README typo & data path & compile flex attn
1 parent ea04127 commit c992e88

File tree

7 files changed

+5027
-5002
lines changed

7 files changed

+5027
-5002
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ weights
4444
checkpoints
4545
ref.py
4646
wandb
47+
.DS_Store

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Each "[h_div_w_template1]_[num_examples].jsonl" file contains lines of dumped js
113113
"long_caption": long caption of the image, required",
114114
"long_caption_type": "InternVL 2.0, required",
115115
"text": "short caption of the image, optional",
116-
"short_caption_type": "user prompt, , optional"
116+
"short_caption_type": "user prompt, optional"
117117
}
118118
```
119119

data/infinity_toy_data/splits/1.000_000002500.jsonl

Lines changed: 2500 additions & 2500 deletions
Large diffs are not rendered by default.

data/infinity_toy_data/splits/1.500_000002500.jsonl

Lines changed: 2500 additions & 2500 deletions
Large diffs are not rendered by default.

evaluation/gen_eval/rename.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
import numpy as np
3+
4+
import json
5+
6+
7+
with open('/Users/bytedance/Desktop/projects/Infinity/evaluation/gen_eval/prompt_rewrite_cache_1.json', 'r') as f:
8+
correct = json.load(f)
9+
10+
with open('/Users/bytedance/Desktop/projects/Infinity/evaluation/gen_eval/prompt_rewrite_cache_123.json', 'r') as f:
11+
false_key_dict = json.load(f)
12+
13+
keys1_list = list(correct.keys())
14+
keys2_list = list(false_key_dict.keys())
15+
16+
final_dict = {}
17+
for i in range(len(keys1_list)):
18+
key1 = keys1_list[i]
19+
key2 = keys2_list[i]
20+
final_dict[key1] = false_key_dict[key2]
21+
22+
with open('prompt_rewrite_cache.json', 'w') as f:
23+
json.dump(final_dict, f, ensure_ascii=False, indent=2)

infinity/models/infinity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def compile_flex_attn(self):
330330
H = self.num_heads,
331331
L = aligned_L)
332332
attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
333-
return attn_fn_compile_dict
333+
return attn_fn_compile_dict
334334

335335
def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
336336
"""

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ einops
1919
openai
2020
httpx==0.20.0
2121
opencv-python
22+
flash_attn

0 commit comments

Comments
 (0)