Commit afe35d4

Merge branch 'EK100' of github.com:HaozheQi/LLaVA-NeXT into shaokai_dev
2 parents 52f6d8e + 083bfad

5 files changed: 169 additions & 12 deletions

.vscode/launch.json

Lines changed: 3 additions & 2 deletions
@@ -7,7 +7,7 @@
             "request": "launch",
             "module": "torch.distributed.run",
             "env": {
-                "CUDA_VISIBLE_DEVICES": "1,2",
+                "CUDA_VISIBLE_DEVICES": "1,2,3",
                 "OMP_NUM_THREADS": "8",
                 "NCCL_IB_DISABLE": "0",
                 "NCCL_IB_GID_INDEX": "3",
@@ -18,7 +18,7 @@
                 "WANDB_API_KEY": "65aeda82a75f1eed29c8e9250b175fcc73dca0d7",
             },
             "args": [
-                "--nproc_per_node=2",
+                "--nproc_per_node=3",
                 "--nnodes=1",
                 "--node_rank=0",
                 "--master_addr=127.0.0.1",
@@ -31,6 +31,7 @@
                 // "--image_folder", "/mediaPFM/data/haozhe/onevision/llava_data",
                 "--image_folder", "/mediaPFM/data/haozhe/onevision/llava_data/geo3k/",
                 "--video_folder", "/mediaPFM/data/haozhe/onevision/llava_video",
+                // "--video_folder", "/home/haozhe/kitchen/AVION/datasets",
                 "--mm_tunable_parts", "mm_vision_tower,mm_mlp_adapter,mm_language_model",
                 "--mm_vision_tower_lr", "2e-6",
                 "--vision_tower", "google/siglip-so400m-patch14-384",

llava/train/train.py

Lines changed: 11 additions & 3 deletions
@@ -45,7 +45,7 @@
 from llava import conversation as conversation_lib
 from llava.model import *
 from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token
-from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord
+from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord, process_EK100_video_with_decord
 
 torch.multiprocessing.set_sharing_strategy("file_system")
 
@@ -1152,9 +1152,13 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
             sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
 
         elif "video" in sources[0]:
-            video_file = self.list_data_dict[i]["video"]
+            video_info = self.list_data_dict[i]["video"]
             video_folder = os.path.join(self.data_args.video_folder, sources[0]['dataset_name'])
-            video_file = os.path.join(video_folder, video_file)
+            if 'EK100' in video_folder:
+                video_file = os.path.join(video_folder, video_info.split("-")[0], video_info.split("-")[1]+".MP4")
+            else:
+                video_file = os.path.join(video_folder, video_info)
+
             suffix = video_file.split(".")[-1]
             if not os.path.exists(video_file):
                 print("File {} not exist!".format(video_file))
@@ -1191,6 +1195,10 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
                                     video.append(frame)
                             except IOError:
                                 print(f"Failed to read frame at path: {frame_path}")
+                elif 'EK100' in video_file:
+                    start_second = float(self.list_data_dict[i]['start_timestamp'])
+                    end_second = float(self.list_data_dict[i]['end_timestamp'])
+                    video, video_time, frame_time, num_frames_to_sample = process_EK100_video_with_decord(video_file, self.data_args, start_second, end_second, 15)
                 else:
                     video, video_time, frame_time, num_frames_to_sample = process_video_with_decord(video_file, self.data_args)
 

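For context on the new EK100 branch in _get_item: an EK100 entry apparently stores the participant and clip id in a single dash-separated string, which is split into a nested path. The snippet below is a minimal sketch of that resolution; the id "P01-P01_103" and the folder are made up for illustration, and the real root comes from --video_folder joined with the dataset name.

import os

# Hypothetical values for illustration only; the real root is
# data_args.video_folder joined with the dataset name (e.g. ".../EK100").
video_folder = "/path/to/llava_video/EK100"
video_info = "P01-P01_103"  # "<participant>-<clip_id>", as stored in the "video" field

participant, clip_id = video_info.split("-")[0], video_info.split("-")[1]
# The EK100 branch points at a directory of chunked clips ("<clip_id>.MP4/<start_sec>.MP4"),
# which process_EK100_video_with_decord then indexes by chunk start time.
video_file = os.path.join(video_folder, participant, clip_id + ".MP4")
print(video_file)  # /path/to/llava_video/EK100/P01/P01_103.MP4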
llava/utils.py

Lines changed: 145 additions & 0 deletions
@@ -22,6 +22,151 @@
 except ImportError:
     print("Please install pyav to use video processing functions.")
 
+def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
+    frame_ids = np.convolve(np.linspace(start_frame, end_frame, num_segments + 1), [0.5, 0.5], mode='valid')
+    if jitter:
+        seg_size = float(end_frame - start_frame - 1) / num_segments
+        shift = (np.random.rand(num_segments) - 0.5) * seg_size
+        frame_ids += shift
+    return frame_ids.astype(int).tolist()
+
+# def get_video_reader(videoname, num_threads, fast_rrc, rrc_params, fast_rcc, rcc_params):
+#     video_reader = None
+#     if fast_rrc:
+#         video_reader = VideoReader(
+#             videoname,
+#             num_threads=num_threads,
+#             width=rrc_params[0], height=rrc_params[0],
+#             use_rrc=True, scale_min=rrc_params[1][0], scale_max=rrc_params[1][1],
+#         )
+#     elif fast_rcc:
+#         video_reader = VideoReader(
+#             videoname,
+#             num_threads=num_threads,
+#             width=rcc_params[0], height=rcc_params[0],
+#             use_rcc=True,
+#         )
+#     else:
+#         video_reader = VideoReader(videoname, num_threads=num_threads)
+#     return video_reader
+
+# def video_loader(root, vid, ext, second, end_second,
+#                  chunk_len=300, fps=30, clip_length=32,
+#                  threads=1,
+#                  fast_rrc=False, rrc_params=(224, (0.5, 1.0)),
+#                  fast_rcc=False, rcc_params=(224, ),
+#                  jitter=False):
+#     assert fps > 0, 'fps should be greater than 0'
+
+#     if chunk_len == -1:
+#         vr = get_video_reader(
+#             osp.join(root, '{}.{}'.format(vid, ext)),
+#             num_threads=threads,
+#             fast_rrc=fast_rrc, rrc_params=rrc_params,
+#             fast_rcc=fast_rcc, rcc_params=rcc_params,
+#         )
+#         end_second = min(end_second, len(vr) / fps)
+
+#         # calculate frame_ids
+#         frame_offset = int(np.round(second * fps))
+#         total_duration = max(int((end_second - second) * fps), clip_length)
+#         frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
+
+#         # load frames
+#         assert max(frame_ids) < len(vr)
+#         try:
+#             frames = vr.get_batch(frame_ids).asnumpy()
+#         except decord.DECORDError as error:
+#             print(error)
+#             frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
+
+#         return torch.from_numpy(frames.astype(np.float32))
+
+#     else:
+#         chunk_start = int(second) // chunk_len * chunk_len
+#         chunk_end = int(end_second) // chunk_len * chunk_len
+#         while True:
+#             video_filename = osp.join(root, '{}.{}'.format(vid, ext), '{}.{}'.format(chunk_end, ext))
+#             if not osp.exists(video_filename):
+#                 # print("{} does not exists!".format(video_filename))
+#                 chunk_end -= chunk_len
+#             else:
+#                 vr = decord.VideoReader(video_filename)
+#                 end_second = min(end_second, (len(vr) - 1) / fps + chunk_end)
+#                 assert chunk_start <= chunk_end
+#                 break
+#         # calculate frame_ids
+#         frame_ids = get_frame_ids(
+#             int(np.round(second * fps)),
+#             int(np.round(end_second * fps)),
+#             num_segments=clip_length, jitter=jitter
+#         )
+#         all_frames = []
+#         # allocate absolute frame-ids into the relative ones
+#         for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
+#             rel_frame_ids = list(filter(lambda x: int(chunk * fps) <= x < int((chunk + chunk_len) * fps), frame_ids))
+#             rel_frame_ids = [int(frame_id - chunk * fps) for frame_id in rel_frame_ids]
+#             vr = get_video_reader(
+#                 osp.join(root, '{}.{}'.format(vid, ext), '{}.{}'.format(chunk, ext)),
+#                 num_threads=threads,
+#                 fast_rrc=fast_rrc, rrc_params=rrc_params,
+#                 fast_rcc=fast_rcc, rcc_params=rcc_params,
+#             )
+#             try:
+#                 frames = vr.get_batch(rel_frame_ids).asnumpy()
+#             except decord.DECORDError as error:
+#                 print(error)
+#                 frames = vr.get_batch([0] * len(rel_frame_ids)).asnumpy()
+#             except IndexError:
+#                 print(root, vid, ext, second, end_second)
+#             all_frames.append(frames)
+#             if sum(map(lambda x: x.shape[0], all_frames)) == clip_length:
+#                 break
+#         res = torch.from_numpy(np.concatenate(all_frames, axis=0).astype(np.float32))
+#         assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids)
+#         return res
+
+def process_EK100_video_with_decord(video_file, data_args, start_second, end_second, chunk_len):
+    fps = 30
+    start_frame = int(start_second * fps)
+    end_frame = int(end_second * fps)
+    chunk_start = int(start_second) // chunk_len * chunk_len
+    chunk_end = int(end_second) // chunk_len * chunk_len
+    video_time = end_second - start_second
+    while True:
+        video_filename = os.path.join(video_file, '{}.MP4'.format(chunk_end))
+        if not os.path.exists(video_filename):
+            # print("{} does not exists!".format(video_filename))
+            chunk_end -= chunk_len
+        else:
+            vr = VideoReader(video_filename, ctx=cpu(0), num_threads=1)
+            end_second = min(end_second, (len(vr) - 1) / fps + chunk_end)
+            assert chunk_start <= chunk_end
+            break
+
+    # calculate frame_ids
+    frame_ids = get_frame_ids(start_frame, end_frame, num_segments=data_args.frames_upbound, jitter=False)
+    frame_time = [i/fps for i in frame_ids]
+
+    all_frames = []
+    # allocate absolute frame-ids into the relative ones
+    for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
+        rel_frame_ids = list(filter(lambda x: int(chunk * fps) <= x < int((chunk + chunk_len) * fps), frame_ids))
+        rel_frame_ids = [int(frame_id - chunk * fps) for frame_id in rel_frame_ids]
+        vr = VideoReader(os.path.join(video_file, '{}.MP4'.format(chunk)),ctx=cpu(0), num_threads=1)
+        frames = vr.get_batch(rel_frame_ids).asnumpy()
+        all_frames.append(frames)
+        vr.seek(0)
+        if sum(map(lambda x: x.shape[0], all_frames)) == data_args.frames_upbound:
+            break
+
+    video = np.concatenate(all_frames, axis=0).astype(np.float32)
+
+    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+    num_frames_to_sample = len(frame_ids)
+
+    return video, video_time, frame_time, num_frames_to_sample
+
 def process_video_with_decord(video_file, data_args):
     vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
     total_frame_num = len(vr)

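As a quick sanity check of the new get_frame_ids helper (reproduced below so the snippet runs on its own): with jitter=False it returns the integer midpoints of num_segments equal segments between the two frame bounds. The example values are arbitrary.

import numpy as np

def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    # Midpoints of num_segments equal segments between start_frame and end_frame.
    frame_ids = np.convolve(np.linspace(start_frame, end_frame, num_segments + 1), [0.5, 0.5], mode='valid')
    if jitter:
        # Shift each midpoint by up to half a segment in either direction.
        seg_size = float(end_frame - start_frame - 1) / num_segments
        shift = (np.random.rand(num_segments) - 0.5) * seg_size
        frame_ids += shift
    return frame_ids.astype(int).tolist()

# Arbitrary example: a 10-second clip at 30 fps (frames 0-300) sampled into 4 segments.
print(get_frame_ids(0, 300, num_segments=4, jitter=False))  # [37, 112, 187, 262]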
run.sh

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # Export environment variables
-export CUDA_VISIBLE_DEVICES="0,1"
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
 export OMP_NUM_THREADS="8"
 export NCCL_IB_DISABLE="0"
 export NCCL_IB_GID_INDEX="3"
@@ -12,14 +12,14 @@ export ACCELERATE_CPU_AFFINITY="1"
 export WANDB_API_KEY="65aeda82a75f1eed29c8e9250b175fcc73dca0d7"
 
 # Run the command using torchrun
-torchrun --nproc_per_node=2 \
+torchrun --nproc_per_node=4 \
     --nnodes=1 \
     --node_rank=0 \
     --master_addr=127.0.0.1 \
     --master_port=29500 \
     llava/train/train_mem.py \
     --deepspeed scripts/zero3.json \
-    --model_name_or_path lmms-lab/llava-onevision-qwen2-7b-ov \
+    --model_name_or_path lmms-lab/llava-onevision-qwen2-0.5b-ov \
     --version qwen_1_5 \
     --data_path scripts/train/onevision.yaml \
     --image_folder /media/data/haozhe/VFM/onevision/llava_data/geo3k/ \
@@ -60,4 +60,4 @@ torchrun --nproc_per_node=2 \
     --torch_compile True \
    --torch_compile_backend inductor \
    --dataloader_drop_last True \
-    --frames_upbound 32 > test7b.out 2>&1
+    --frames_upbound 32 > train_kitchen0.5b.out 2>&1

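Both launch.json and run.sh bump the visible GPUs together with --nproc_per_node. A small illustrative check (not part of the commit) of keeping the two values in sync:

import os

# Illustrative only: torchrun's --nproc_per_node should equal the number of
# devices listed in CUDA_VISIBLE_DEVICES ("0,1,2,3" -> 4 in the updated run.sh).
visible = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3")
nproc_per_node = len([d for d in visible.split(",") if d != ""])
print(f"torchrun --nproc_per_node={nproc_per_node} --nnodes=1 llava/train/train_mem.py ...")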
scripts/train/onevision.yaml

Lines changed: 6 additions & 3 deletions
@@ -68,8 +68,8 @@ datasets:
   # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mathqa_29837.json
   #   sampling_strategy: "all"
   # - json_path: /media/data/haozhe/VFM/onevision/llava_instruct/geo3k.json
-  - json_path: /mediaPFM/data/haozhe/onevision/llava_instruct/geo3k.json
-    sampling_strategy: "all"
+  # - json_path: /mediaPFM/data/haozhe/onevision/llava_instruct/geo3k.json
+  #   sampling_strategy: "all"
   # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_qa_converted_67833.json
   #   sampling_strategy: "first:10%"
   # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_align_converted_60252.json
@@ -183,4 +183,7 @@ datasets:
   # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/0718_0_30_s_academic_mc_v0_1_all.json # will be released in next version of LLaVA-NeXT-Video
   #   sampling_strategy: all
   # - json_path: /media/data/haozhe/VFM/onevision/llava_instruct/sharegpt4video.json # download from sharegpt4video
-  #   sampling_strategy: all
+  # - json_path: /mediaPFM/data/haozhe/onevision/llava_instruct/sharegpt4video.json
+  #   sampling_strategy: "first:10%"
+  - json_path: /media/data/haozhe/VFM/onevision/llava_instruct/train_convs_narration.jsonl
+    sampling_strategy: all

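Given the fields the updated _get_item reads ("video", "start_timestamp", "end_timestamp", plus the usual "conversations"), an entry in the newly listed train_convs_narration.jsonl presumably looks roughly like the sketch below. All values are invented for illustration, and the dataset name used to build the video folder (likely "EK100") may be injected by the dataset loader rather than stored per entry.

import json

# Hypothetical EK100 narration entry; field names follow what train.py reads,
# the values are made up for illustration only.
example_entry = {
    "video": "P01-P01_103",        # resolved to <video_folder>/EK100/P01/P01_103.MP4/ by _get_item
    "start_timestamp": "12.5",     # seconds; cast to float before process_EK100_video_with_decord
    "end_timestamp": "18.0",
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is the person doing?"},
        {"from": "gpt", "value": "The person is chopping an onion."},
    ],
}
print(json.dumps(example_entry))  # one JSON object per line in the .jsonl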