
Commit 819a612

add stream inference code

1 parent 245e3f4 · commit 819a612

6 files changed: +403 -4 lines changed

llava/model/builder.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -24,16 +24,20 @@
 from llava.utils import rank0_print
 
 
-def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
     kwargs["device_map"] = device_map
 
     if load_8bit:
         kwargs["load_in_8bit"] = True
     elif load_4bit:
         kwargs["load_in_4bit"] = True
         kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
-    else:
+    elif torch_dtype == "float16":
         kwargs["torch_dtype"] = torch.float16
+    elif torch_dtype == "bfloat16":
+        kwargs["torch_dtype"] = torch.bfloat16
+    else:
+        import pdb; pdb.set_trace()
 
     if customized_config is not None:
         kwargs["config"] = customized_config
```

llava/model/llava_arch.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -93,6 +93,7 @@ def initialize_vision_modules(self, model_args, fsdp=None):
         self.config.mm_vision_select_feature = mm_vision_select_feature
         self.config.mm_patch_merge_type = mm_patch_merge_type
 
+
         if not hasattr(self.config, 'add_faster_video'):
             if model_args.add_faster_video:
                 embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
@@ -227,7 +228,7 @@ def add_token_per_grid(self, image_feature):
         image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
         image_feature = image_feature.flatten(1, 2).flatten(2, 3)
         image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
-        if self.config.add_faster_video:
+        if getattr(self.config, "add_faster_video", False):
             # import pdb; pdb.set_trace()
             # (3584, 832, 14) -> (3584, 64, 13, 14)
             image_feature = image_feature.view(feature_dim, num_frames, resize_h, -1)
@@ -311,7 +312,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
                 if mm_newline_position == "grid":
                     # Grid-wise
                     image_feature = self.add_token_per_grid(image_feature)
-                    if self.config.add_faster_video:
+                    if getattr(self.config, "add_faster_video", False):
                         faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx])
                         # Add a token for each frame
                         concat_slow_fater_token = []
```
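Both call sites now read `add_faster_video` through `getattr` with a `False` default, so configs serialized before the flag existed no longer raise `AttributeError` at load time. The pattern in isolation, with `SimpleNamespace` standing in for the real config object:

```python
from types import SimpleNamespace

old_config = SimpleNamespace()  # config saved before add_faster_video existed

# old_config.add_faster_video          # would raise AttributeError
flag = getattr(old_config, "add_faster_video", False)  # degrades to False
assert flag is False
```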
Lines changed: 166 additions & 0 deletions (new file)

```python
import base64
import select
import sys
import warnings

import cv2
import numpy as np
import openai

warnings.filterwarnings("ignore")

# Global variables for storing video frames and their respective times
video_frames = []
frame_times = []
history_time = 0


client = openai.Client(api_key="EMPTY", base_url="xxx")


def encode_image(frames):
    """Encode a list of frames as base64 JPEG strings."""
    base64_frames = []
    for frame in frames:
        # frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # Convert BGR to RGB
        _, buffer = cv2.imencode(".jpg", frame)
        base64_frames.append(base64.b64encode(buffer).decode("utf-8"))
    return base64_frames


# Function to send frames to the server and get a response
def request_server(question, base64_frames):
    messages = [{"role": "user", "content": []}]
    for base64_frame in base64_frames:
        frame_format = {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"},
            "modalities": "video",
        }
        messages[0]["content"].append(frame_format)

    prompt = {"type": "text", "text": question}
    messages[0]["content"].append(prompt)

    video_request = client.chat.completions.create(
        model="llava-onevision-72b-ov",
        messages=messages,
        temperature=0,
        max_tokens=1024,
    )
    return video_request.choices[0].message.content


class Args:
    """Class to store configuration arguments."""

    def __init__(self, frame_limit=30, force_sample=False):
        self.frame_limit = frame_limit  # Max number of frames to retrieve
        self.force_sample = force_sample  # Whether to force uniform sampling


# Function to capture frames from the camera until the user presses Enter
def load_camera_frames_until_enter(args):
    global history_time  # Maintained across multiple captures

    cap = cv2.VideoCapture(0)  # 0 is the ID for the default camera
    if not cap.isOpened():
        print("Error: Could not access the camera.")
        return None, None, None

    fps = cap.get(cv2.CAP_PROP_FPS) or 30  # Default to 30 FPS if FPS is unavailable
    frame_count = 0

    print("Video capturing started. Press 'Enter' in the console to stop capturing.")

    while True:
        ret, frame = cap.read()
        if not ret:
            print("Error: Could not read frame from camera.")
            break

        frame_count += 1
        cur_frame_time = frame_count / fps

        video_frames.append(frame)
        frame_times.append(cur_frame_time + history_time)

        # Display the frame
        cv2.imshow('Camera Feed', frame)

        # cv2.waitKey keeps the window responsive; 'q' in the window also stops capture
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        # Check whether the user pressed 'Enter' in the console
        if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
            input()  # Consume the "Enter" key press
            print("Video capture stopped.")
            break

    cap.release()
    cv2.destroyAllWindows()  # Close the camera feed window

    history_time = frame_times[-1] if frame_times else history_time

    # Uniformly sample at most args.frame_limit frames
    total_frames = len(video_frames)
    print(f"Total Frames Captured: {total_frames}")

    if total_frames > args.frame_limit:
        sample_indices = np.linspace(0, total_frames - 1, args.frame_limit, dtype=int)
        sampled_frames = [video_frames[i] for i in sample_indices]
        sampled_times = [frame_times[i] for i in sample_indices]
    else:
        sampled_frames = video_frames
        sampled_times = frame_times

    frame_times_str = ",".join([f"{t:.2f}s" for t in sampled_times])
    return np.array(sampled_frames), frame_times_str, history_time


# Function to stream video, process it, and answer a user question
def stream_camera_and_ask_question(args):
    video_frames, frame_times, video_time = load_camera_frames_until_enter(args)

    if video_frames is None:
        print("Error capturing video frames.")
        return

    question = input("Enter your query for the current video: ").strip().lower()

    print("question: ", question)
    image_base64 = encode_image(video_frames)
    response = request_server(question, image_base64)

    print(f"Model's Answer: {response}")
    print(f"Video Duration: 0 to {video_time:.2f} seconds")
    print(f"Frame Times: {frame_times}")

    return response


# Main loop to keep the system running and waiting for user input
def main_loop():
    args = Args(frame_limit=64, force_sample=True)

    while True:
        answer = stream_camera_and_ask_question(args)
        if answer is None:
            print("Exiting the loop.")
            break

        user_input = input("Press 'Enter' to capture again, or 'q' to quit: ").strip().lower()
        if user_input == "q":
            print("Quitting the demo.")
            break

    # Close all OpenCV windows after the user quits
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main_loop()
```
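Assuming an OpenAI-compatible endpoint serving llava-onevision-72b-ov is reachable at the `base_url` configured above, the request path can be smoke-tested without a camera by substituting synthetic frames; this sketch reuses the script's own helpers:

```python
import numpy as np

# Hypothetical smoke test: two gray 480x640 BGR frames stand in for camera input.
fake_frames = [np.full((480, 640, 3), 128, dtype=np.uint8) for _ in range(2)]
frames_b64 = encode_image(fake_frames)
print(request_server("How many frames do you see?", frames_b64))
```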

playground/demo/video_demo.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -153,6 +153,8 @@ def run_inference(args):
     else:
         args.force_sample = False
 
+    # import pdb;pdb.set_trace()
+
     if getattr(model.config, "add_time_instruction", None) is not None:
         args.add_time_instruction = model.config.add_time_instruction
     else:
```
Lines changed: 113 additions & 0 deletions (new file)

```bash
#!/bin/bash

# You should complete the paths for the following attributes:
PROJECT_ROOT="XXXX"
## This can be a yaml file for multiple files or a json file for a single file
DATA_PATH="XXXX"
IMAGE_FOLDER="XXXX"
VIDEO_FOLDER="XXXX"

export PYTHONWARNINGS="ignore"

############### Prepare Envs #################
cd $PROJECT_ROOT
python3 -m pip install --upgrade pip
python3 -m pip install -e ".[train]"

python3 -m pip install ninja
python3 -m pip install flash-attn --no-build-isolation
alias python=python3
############### Show Envs ####################

nvidia-smi
# Take the first port of worker 0
ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
port=${ports[0]}
port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2222}" | awk -F',' '{print $1}')"

echo "total workers: ${ARNOLD_WORKER_NUM}"
echo "cur worker id: ${ARNOLD_ID}"
echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
echo "master ip: ${METIS_WORKER_0_HOST}"
echo "master port: ${port}"
echo "master port in cmd: ${port_in_cmd}"

export OMP_NUM_THREADS=8
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
# export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE}
export NCCL_SOCKET_IFNAME=eth0
export NCCL_DEBUG=WARN

PORT=26000
GPUS="0,1,2,3,4,5,6,7"

################ Arnold Jobs ################

LLM_VERSION="Qwen/Qwen2-72B-Instruct"
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"

# Stage for video
PROMPT_VERSION="qwen_1_5"
MID_RUN_NAME="llava_next_video-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video"
PREV_STAGE_CHECKPOINT=""
echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}"
echo "MID_RUN_NAME: ${MID_RUN_NAME}"

# NUM_GPUS, NNODES, RANK, and ADDR are expected to be provided by the launcher environment.
ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \
    llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path $PREV_STAGE_CHECKPOINT \
    --version $PROMPT_VERSION \
    --data_path ${DATA_PATH} \
    --image_folder ${IMAGE_FOLDER} \
    --video_folder ${VIDEO_FOLDER} \
    --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
    --mm_vision_tower_lr=2e-6 \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --group_by_modality_length True \
    --image_aspect_ratio anyres_max_9 \
    --image_grid_pinpoints "(1x1),...,(6x6)" \
    --mm_patch_merge_type spatial_unpad \
    --bf16 True \
    --run_name $MID_RUN_NAME \
    --output_dir ./work_dirs/$MID_RUN_NAME \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 12768 \
    --gradient_checkpointing True \
    --dataloader_num_workers 2 \
    --lazy_preprocess True \
    --report_to wandb \
    --torch_compile True \
    --torch_compile_backend "inductor" \
    --dataloader_drop_last True \
    --frames_upbound 32 \
    --mm_newline_position grid \
    --add_time_instruction True \
    --force_sample True \
    --mm_spatial_pool_stride 2
exit 0;
```
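The torchrun line assumes NUM_GPUS, NNODES, RANK, and ADDR (plus the METIS_* and ARNOLD_* variables) are injected by the cluster scheduler; none are defined in the script itself. A hypothetical single-node dry run might export them by hand before invoking the script (the script filename below is a placeholder):

```bash
# Hypothetical single-node defaults; in production these come from the scheduler.
export NUM_GPUS=8 NNODES=1 RANK=0 ADDR=127.0.0.1
export METIS_WORKER_0_PORT=2222 METIS_WORKER_0_HOST=127.0.0.1
export ARNOLD_WORKER_NUM=1 ARNOLD_ID=0 ARNOLD_WORKER_GPU=8
bash train_ov_to_video.sh  # placeholder name for this script
```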
