Skip to content

Commit 0cacc5f

Browse files
committed
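Fixed bugs: load the LLaVA-OneVision model once before the evaluation loop (instead of inside every llava_inference call), mirror log output to the console, and pass --llm_size / --llava_num_frames in run_EK100.sh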
1 parent 57426c8 commit 0cacc5f

File tree

3 files changed

+38
-16
lines changed

3 files changed

+38
-16
lines changed

action/dataset.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,21 @@ def get_args_parser():
610610

611611
return parser
612612

613+
def prepare_llava(pretrained=None, llm_size="7b"):
    """Load the LLaVA-OneVision model and its companion objects once.

    Loading is kept separate from inference so the caller can build the
    model a single time and reuse it for every sample.

    Args:
        pretrained: Checkpoint identifier forwarded to
            ``load_pretrained_model``. When omitted, a module-level
            ``pretrained`` variable is used if the script entry point has
            set one; otherwise the identifier is built from ``llm_size``.
        llm_size: Size tag interpolated into the default checkpoint name
            (``lmms-lab/llava-onevision-qwen2-{llm_size}-ov``). Only used
            when no explicit or module-level ``pretrained`` is available.

    Returns:
        The ``(tokenizer, model, image_processor, max_length)`` tuple
        produced by ``load_pretrained_model``.
    """
    import warnings
    from llava.model.builder import load_pretrained_model

    warnings.filterwarnings("ignore")

    if pretrained is None:
        # Backward compatibility: the script entry point assigns a
        # module-level ``pretrained`` before calling this function with no
        # arguments. Look it up explicitly instead of relying on an
        # implicit global, which raised NameError for any other caller.
        pretrained = globals().get("pretrained")
    if pretrained is None:
        pretrained = f"lmms-lab/llava-onevision-qwen2-{llm_size}-ov"

    model_name = "llava_qwen"
    device_map = "auto"
    tokenizer, model, image_processor, max_length = load_pretrained_model(
        pretrained,
        None,
        model_name,
        device_map=device_map,
        attn_implementation="sdpa",
    )

    return tokenizer, model, image_processor, max_length
626+
627+
613628
if __name__ == '__main__':
614629
from moviepy.editor import ImageSequenceClip
615630
import torchvision
@@ -640,10 +655,24 @@ def get_args_parser():
640655

641656
# Set up logging
642657
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename=f'llava_ov_{args.llava_num_frames}f_{args.llm_size}.log', filemode='w')
658+
659+
console_handler = logging.StreamHandler(sys.stdout)
660+
console_handler.setLevel(logging.INFO)
661+
662+
# Set the same format for console handler as well
663+
console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
664+
665+
# Add the console handler to the root logger
666+
logging.getLogger().addHandler(console_handler)
667+
643668
logger = logging.getLogger(__name__)
644669

670+
pretrained = f"lmms-lab/llava-onevision-qwen2-{args.llm_size}-ov"
671+
672+
tokenizer, model, image_processor, max_length = prepare_llava()
673+
645674
for idx, (frames, gt) in tqdm(enumerate(val_dataloader)):
646-
pred = llava_inference(frames, gt, logger, num_frames=args.llava_num_frames, llm_size=args.llm_size)
675+
pred = llava_inference(frames, tokenizer, model, image_processor, max_length, gt, num_frames=args.llava_num_frames)
647676

648677
# if valid letter is found in the prediction, then we will use that as the prediction
649678
found = False
@@ -671,4 +700,4 @@ def get_args_parser():
671700
# get final accuracy
672701
accuracy = np.mean(gts == preds)
673702
logger.info(f'Final accuracy: {accuracy:.4f}')
674-
703+

action/llava_ov_inference.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from operator import attrgetter
2-
from llava.model.builder import load_pretrained_model
2+
33
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
44
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
55
from llava.conversation import conv_templates, SeparatorStyle
@@ -13,22 +13,13 @@
1313
import warnings
1414
from decord import VideoReader, cpu
1515

16-
def llava_inference(video_frames, gt, logger, num_frames=16, llm_size='7b'):
1716

18-
warnings.filterwarnings("ignore")
19-
# Load the OneVision model
20-
pretrained = f"lmms-lab/llava-onevision-qwen2-{llm_size}-ov"
21-
logger.info(f"Loading model {pretrained}")
22-
model_name = "llava_qwen"
23-
device = "cuda"
24-
device_map = "auto"
25-
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
17+
def llava_inference(video_frames, tokenizer, model, image_processor, max_length, gt, num_frames=16):
2618

27-
model.eval()
19+
model.eval()
20+
device = "cuda"
2821
video_frames = video_frames[0]
29-
3022
temporal_stride = 16 // num_frames
31-
3223
video_frames = video_frames[::temporal_stride]
3324
image_tensors = []
3425
frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().cuda()

run_EK100.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
python3 action/dataset.py \
22
--root /media/data/haozhe/VFM/EK100/EK100_320p_15sec_30fps_libx264 \
33
--train-metadata /media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv \
4-
--val-metadata /media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv > kitchen_test.out 2>&1
4+
--val-metadata /media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv \
5+
--llm_size 7b \
6+
--llava_num_frames 16 > kitchen_test.out 2>&1

0 commit comments

Comments
 (0)