
Commit 819d73a

Author: Ye Shaokai (committed)

refactored the code for consistent prompting
1 parent 6ff8a9f commit 819d73a

File tree

10 files changed: +108 additions, -58 deletions
Lines changed: 10 additions & 1 deletion

@@ -11,7 +11,7 @@
 import cv2
 from pathlib import Path
 from tqdm import tqdm
-from action.prediction_analysis import PredictionAnalysis
+from llava.action.prediction_analysis import PredictionAnalysis

 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

@@ -145,6 +145,15 @@ def extract_frames(self, vid_path, start_second, end_second):
         return frames, time_meta


+class GPTDataCleaner(ChatGPT):
+    """
+    Cleans the training annotations.
+    Instead of using the first verb that appears in the verb csv, we use the csv file to
+    have ChatGPT select the best one.
+    We also inject rules to correct some confusing conventions in how EK100 names verbs.
+    """
+
+
 class GPTInferenceAnnotator(ChatGPT):
     """
     Given the images, this class will annotate the video frames
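
The GPTDataCleaner class is only stubbed out in this commit (a docstring, no body). A minimal sketch of how such a cleaner might query ChatGPT, reusing the module-level client shown above; the method name, prompt wording, and model choice are hypothetical, not part of the commit:

# Hypothetical sketch, not from the commit: have ChatGPT pick the best verb
# among the EK100 csv candidates for one narration.
class GPTDataCleaner(ChatGPT):

    def pick_best_verb(self, narration, verb_candidates):
        # verb_candidates: all verbs listed for the narration's verb class in the EK100 verb csv
        prompt = (
            f"The narration is: '{narration}'.\n"
            f"Candidate verbs: {', '.join(verb_candidates)}.\n"
            "Reply with the single verb that best describes the action."
        )
        response = client.chat.completions.create(
            model="gpt-4o",  # assumed model; any chat-completion model works
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content.strip()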
Lines changed: 4 additions & 25 deletions

@@ -13,20 +13,16 @@
 from pathlib import Path
 import sys
 import os
-sys.path[0] = os.path.dirname(sys.path[0])
-from action.llava_ov_inference import llava_inference
+from llava.action.llava_inference import llava_inference
 import json
 import logging
 from llava.utils import rank0_print
-from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions, avion_video_loader, create_multi_choice_from_avion_predictions
-from action.prediction_analysis import PredictionAnalysis
-import copy
+from llava.action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions, avion_video_loader, create_multi_choice_from_avion_predictions
+from llava.action.prediction_analysis import PredictionAnalysis
 from collections import Counter
 import torch.distributed as dist


-
-
 def setup(rank, world_size):
     # Check if the process group is already initialized
     if not dist.is_initialized():

@@ -229,7 +225,6 @@ def get_args_parser():
     parser.add_argument('--use-multi-epochs-loader', action='store_true')

     # llava related
-    # llm size is type of string and can only be '7b' or '5b' etc.
     parser.add_argument('--pretrained_name', default='', type=str, help='the name in huggingface')
     parser.add_argument('--llava_num_frames', default=16, type=int, help='number of frames for llava')
     ## avion refinement

@@ -467,23 +462,7 @@ def evaluate_on_EK100(eval_args,
     logger.info(f'Process {dist.get_rank()} - local_total_samples: {local_total_samples:.4f}')
     logger.info(f'Process {dist.get_rank()} - local_llava_correct: {llava_correct:.4f}')
     logger.info(f'Process {dist.get_rank()} - local_running_corrects: {local_running_corrects:.4f}')
-
-
-    # Calculate and log running mean accuracy
-    # dist.barrier()
-    # dist.all_reduce(local_running_corrects, op=dist.ReduceOp.SUM)
-    # dist.all_reduce(local_total_samples, op=dist.ReduceOp.SUM)
-    # if eval_args.action_predictions:
-    #     dist.all_reduce(local_avion_correct, op=dist.ReduceOp.SUM)
-    # dist.barrier()
-    # # Calculate global accuracy after reduction
-    # local_running_accuracy = local_running_corrects.item() / local_total_samples.item()
-    # local_avion_accuracy = local_avion_correct.item() / local_total_samples.item()
-
-    # logger.info(f'Process {dist.get_rank()} - Running accuracy: {local_running_accuracy:.4f}')
-    # logger.info(f'Process {dist.get_rank()} - AvionRunning accuracy: {local_avion_accuracy:.4f}')
-
+

     dist.barrier()
     dist.all_reduce(global_running_corrects, op=dist.ReduceOp.SUM)
Lines changed: 3 additions & 2 deletions

@@ -3,7 +3,8 @@
 import csv
 import os
 import argparse
-from action.utils import generate_label_map, MultiChoiceGenerator, AvionMultiChoiceGenerator
+import sys
+from llava.action.utils import generate_label_map, MultiChoiceGenerator, AvionMultiChoiceGenerator, format_task_related_prompt
 from pathlib import Path


@@ -78,7 +79,7 @@ def generate_naive_conversation(vn_str:str):

 def generate_random_mc_conversation(options:list[str], gt_answer_letter, gt_answer_name):
     return [
-        {"from": "human", "value": f"<image>\n the video is taken from egocentric view. What action is the person performing? Please select the letter for the right answer {options}"},
+        {"from": "human", "value": f"{options}"},
         {"from": "gpt", "value": f"{gt_answer_letter}. {gt_answer_name}"}
     ]
Lines changed: 8 additions & 7 deletions

@@ -5,6 +5,7 @@
 import torch
 import numpy as np
 import copy
+from llava.action.utils import format_llava_prompt


 def llava_ov_process(video_frames,

@@ -31,7 +32,7 @@ def llava_ov_process(video_frames,

     question = mc_data['question'][0]
     options = mc_data['options'][0]
-
+
     question = f"{DEFAULT_IMAGE_TOKEN}\n{question}:{options}"

     conv = copy.deepcopy(conv_templates[conv_template])

@@ -82,21 +83,21 @@ def llava_video_process(

     video_duration = time_meta['duration'].item()
     n_frames = time_meta['n_frames'].item()
-    frame_time = time_meta['frame_time']
-    print ('frame time', frame_time)
-    frame_time = frame_time[0]
-    time_instruciton = f"You are seeing a video taken from egocentric view. The video lasts for {video_duration:.2f} seconds, and {n_frames} frames are uniformly sampled from it. What is the person doing? Format your answer letter. verb noun such as A. move knife."

     frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].cuda().to(torch.bfloat16)

     image_tensors.append(frames)

     conv_template = "qwen_1_5"

-    question = mc_data['question'][0]
     options = mc_data['options'][0]

-    question = DEFAULT_IMAGE_TOKEN + f"{time_instruciton}\n:{options}"
+    question = format_llava_prompt(DEFAULT_IMAGE_TOKEN,
+                                   options,
+                                   video_duration,
+                                   n_frames,
+                                   include_frame_time=True,
+                                   include_time_instruction=True)

     print ('what is the question')
     print (question)
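
With include_frame_time=True, the prompt sent to LLaVA-Video now also spells out where each sampled frame sits in the clip. The frame times come from format_time_instruction, added in llava/action/utils.py below; a quick check of that arithmetic with made-up numbers:

# Frame times as computed in format_time_instruction: i * (duration / n_frames).
video_duration, n_frames = 2.0, 4
frame_time = [i * (video_duration / n_frames) for i in range(n_frames)]
print(",".join(f"{t:.2f}s" for t in frame_time))
# prints: 0.00s,0.50s,1.00s,1.50s  (offset of each frame from the clip start)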
Lines changed: 2 additions & 2 deletions

@@ -169,9 +169,9 @@ def analysis(self):
 if __name__ == '__main__':

     # at rcp server
-    #save_folder = '/storage-rcp-pure/upmwmathis_scratch/shaokai/LLaVA-NeXT/llavavideo_avion_mc_top10_5epoch_preds'
+    save_folder = '/storage-rcp-pure/upmwmathis_scratch/shaokai/LLaVA-NeXT/llavavideo_avion_mc_top10_5epoch_preds_without_frame_time'
     # at amg0
-    save_folder = '/data/epic_kitchen/llavavideo_avion_mc_top10_5epoch_preds'
+    #save_folder = '/data/epic_kitchen/llavavideo_avion_mc_top10_5epoch_preds'


     prediction_analysis = PredictionAnalysis(save_folder = save_folder,

action/utils.py renamed to llava/action/utils.py

Lines changed: 45 additions & 4 deletions

@@ -43,10 +43,53 @@ def generate_label_map(anno_root):
     return labels, mapping_vn2act, verb_maps, noun_maps


+def format_task_related_prompt(option_list):
+    prefix = "The video is taken from egocentric view. What action is the person performing? Given multiple choices, format your answer as 'option letter. option name', such as 'A. move knife', where A is the option letter and 'move knife' is the option name.\n"
+    assert isinstance(option_list, list)
+    suffix = ",".join(option_list)
+    suffix = "Here are the options you are given:\n" + suffix
+    ret = prefix + suffix
+    return ret
+
+def format_time_instruction(video_duration, n_frames, include_frame_time=False):
+
+    prefix = f"You are seeing a video taken from egocentric view. The video lasts for {video_duration:.2f} seconds, and {n_frames} frames are uniformly sampled from it."
+
+    frame_time = [i * (video_duration / n_frames) for i in range(n_frames)]
+    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+
+    suffix = ""
+    if include_frame_time:
+        suffix = f"These frames are located at {frame_time}."
+
+    return prefix + suffix
+
+
+def format_llava_prompt(image_token,
+                        option_list,
+                        video_duration,
+                        n_frames,
+                        include_time_instruction=False,
+                        include_frame_time=False
+                        ):
+    """
+    baseline llava prompt: {image_token}\n{task_related_prompt}
+    with time instruction: {image_token}\n{time_instruction}\n{task_related_prompt}
+    """
+    task_related_prompt = format_task_related_prompt(option_list)
+    time_instruction = format_time_instruction(video_duration, n_frames, include_frame_time)
+
+    if include_time_instruction:
+        ret = f"{image_token}\n{time_instruction}{task_related_prompt}"
+    else:
+        ret = f"{image_token}\n{task_related_prompt}"
+
+    return ret
+
 def match_answer(pred, gt):
     return pred == gt

-
 def parse_avion_predictions(predictions):
     return [pred.replace(':', ' ', 1) for pred in predictions]

@@ -90,7 +133,6 @@ def generate_multi_choice(self, gt_vn, k):

         gt_letter = letters[answers.index(gt_answer)]
         data = {
-            'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
             'options': {0: options},
             # the correct letter in mc
             # for inspecting

@@ -142,8 +184,8 @@ def generate_multi_choice(self, gt_vn, avion_predictions, k):

         gt_letter = letters[answers.index(gt_answer)]

+
         data = {
-            'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
             'options': {0: options},
             # the correct letter in mc
             # for inspecting

@@ -194,7 +236,6 @@ def create_multi_choice_from_avion_predictions(avion_predictions, k):
         options[i] = f'{letters[i]}. {predictions[i]}'

     mc_data = {
-        'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer.'},
         'options': {0: options},
         'valid_letters': letters,
         'avion_pred': predictions[0]
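
Putting the three new helpers together, training and inference now build the same prompt through one code path. An illustrative run with the literal image token and two made-up options:

from llava.action.utils import format_llava_prompt

prompt = format_llava_prompt("<image>",
                             ["A. move knife", "B. open drawer"],
                             video_duration=2.0,
                             n_frames=4,
                             include_time_instruction=True,
                             include_frame_time=True)
print(prompt)
# <image>
# You are seeing a video taken from egocentric view. The video lasts for 2.00 seconds,
# and 4 frames are uniformly sampled from it.These frames are located at
# 0.00s,0.50s,1.00s,1.50s.The video is taken from egocentric view. What action is the
# person performing? ...
# (note that format_time_instruction's output is concatenated without a separator,
# hence "it.These")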

llava/train/llava_trainer.py

Lines changed: 3 additions & 2 deletions

@@ -17,6 +17,8 @@
 from transformers.trainer_pt_utils import AcceleratorConfig
 from typing import List, Optional
 from datetime import timedelta
+import llava
+from llava.action.ek_eval import evaluate_on_EK100

 if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs

@@ -248,8 +250,7 @@ def __init__(self, *args, tokenizer = None, eval_args = None, model_max_length =



-    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
-        from action.ek_eval import evaluate_on_EK100
+    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):

         accuracy = evaluate_on_EK100(self.eval_args, self.model, self.tokenizer)
llava/train/train.py

Lines changed: 27 additions & 9 deletions

@@ -36,7 +36,8 @@
 import transformers
 import tokenizers
 import deepspeed
-
+import sys
+import llava
 from transformers import AutoConfig
 from torch.utils.data import Dataset
 from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX

@@ -47,6 +48,7 @@
 from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token
 from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord, process_EK100_video_with_decord

+from llava.action.utils import format_llava_prompt

 torch.multiprocessing.set_sharing_strategy("file_system")

@@ -978,10 +980,11 @@ def get_tokenize_len(prompts):


 class LazySupervisedDataset(Dataset):
-    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
+    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, eval_args):
         super(LazySupervisedDataset, self).__init__()
         self.tokenizer = tokenizer
-        self.list_data_dict = []
+        self.list_data_dict = []
+        self.eval_args = eval_args

         # Handle multiple JSON files specified in the data_path
         if "{" in data_path and "}" in data_path:

@@ -1231,9 +1234,24 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:

         processor = self.data_args.image_processor
         image = processor.preprocess(video, return_tensors="pt")["pixel_values"]
-        if self.data_args.add_time_instruction:
-            time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {num_frames_to_sample} frames are uniformly sampled from it. Please answer the following questions related to this video."
-            sources[0]["conversations"][0]["value"] = f'{DEFAULT_IMAGE_TOKEN}\n{time_instruciton}\n{sources[0]["conversations"][0]["value"].replace(DEFAULT_IMAGE_TOKEN, "")}'
+        if 'EK100' not in video_file:
+            if self.data_args.add_time_instruction:
+                time_instruction = f"The video lasts for {video_time:.2f} seconds, and {num_frames_to_sample} frames are uniformly sampled from it. Please answer the following questions related to this video."
+                sources[0]["conversations"][0]["value"] = f'{DEFAULT_IMAGE_TOKEN}\n{time_instruction}\n{sources[0]["conversations"][0]["value"].replace(DEFAULT_IMAGE_TOKEN, "")}'
+        else:
+            # We use our own prompting logic when it's EK100
+            options = eval(sources[0]["conversations"][0]["value"])
+            assert isinstance(options, list)
+            assert len(options) == self.eval_args.topk_predictions
+            # We only store the option list in the annotation file, to make it easier to use consistent prompting
+            llava_prompt = format_llava_prompt(DEFAULT_IMAGE_TOKEN,
+                                               options,
+                                               video_time,
+                                               num_frames_to_sample,
+                                               include_time_instruction=self.data_args.add_time_instruction,
+                                               include_frame_time=True)
+            sources[0]["conversations"][0]["value"] = llava_prompt
+
         image = [(image, video[0].size, "video")]
         sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
         # print(sources)

@@ -1322,9 +1340,9 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
         return batch


-def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, eval_args) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
-    train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
+    train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args, eval_args=eval_args)
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
     return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

@@ -1728,7 +1746,7 @@ def make_inputs_require_grad(module, input, output):
        if training_args.bf16 and module.weight.dtype == torch.float32:
            module = module.to(torch.bfloat16)

-    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, eval_args=eval_args)

     eval_args.pretrained_name = model_args.model_name_or_path.split('/')[1]
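
For the EK100 branch, the annotation "value" field now stores only the stringified option list, which _get_item parses with eval(). If hardening that parse step is desired, ast.literal_eval is a drop-in replacement for this case (a suggestion, not what the commit does):

import ast

# The EK100 annotation "value" field holds a stringified option list,
# e.g. '["A. move knife", "B. open drawer"]' (contents made up).
raw_value = '["A. move knife", "B. open drawer"]'

# literal_eval only accepts Python literals, so a malformed or malicious
# annotation raises an error instead of executing arbitrary code.
options = ast.literal_eval(raw_value)
assert isinstance(options, list)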

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
 datasets:
-  # - json_path: /data/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl
-  - json_path: /capstor/scratch/cscs/hqi/llava/EK100/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl
+  - json_path: /data/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl
+  # - json_path: /capstor/scratch/cscs/hqi/llava/EK100/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl
     sampling_strategy: all

shaokai_generate_train.sh

Lines changed: 4 additions & 4 deletions

@@ -1,7 +1,7 @@
-python3 action/generate_description.py \
-    --train_metadata /data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_train.csv \
-    --out_folder /data/EK100_inst_train/ \
-    --avion_train_predictions /data/epic_kitchen/avion_predictions_train.json \
+python3 llava/action/generate_description.py \
+    --train_metadata /data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv \
+    --out_folder /data/shaokai/EK100_inst_train/ \
+    --avion_train_predictions /data/shaokai/avion_predictions_train.json \
     --gen_type avion_mc \
     --n_options 10