Skip to content

Commit ab28d4c

Browse files
author
Haozhe Qi
committed
fixed random seed in utils
1 parent 20a11bb commit ab28d4c

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

llava/action/ek_eval.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,11 @@ def prepare_llava(pretrained):
143143
device_map = "auto"
144144

145145
overwrite_config = None
146-
if 'video' in pretrained or 'Video' in pretrained or '7b' in pretrained:
147-
overwrite_config = {'tie_word_embeddings': False, 'use_cache': True, "vocab_size": 152064}
148-
146+
if 'ov' not in pretrained:
147+
if 'video' in pretrained or 'Video' in pretrained or '7b' in pretrained:
148+
overwrite_config = {'tie_word_embeddings': False, 'use_cache': True, "vocab_size": 152064}
149+
else:
150+
pass
149151

150152
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained,
151153
None,

llava/action/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import json
2020
from llava.utils import rank0_print
2121

22+
# set random seed
23+
random.seed(42)
24+
2225
def remove_sub_nouns(nlp, narration, verb, nouns):
2326
narration = copy.deepcopy(narration)
2427
noun_list = ast.literal_eval(nouns)
@@ -431,7 +434,9 @@ def generate_multi_choice(self,
431434
def train_generate(self, gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = False):
432435
# letters as A, B, C, D, .. Note we maximally support 26 letters
433436
letters = [chr(65+i) for i in range(26)][:k]
434-
answer_list = [vn for vn in mapping_vn2narration.keys()]
437+
answer_list = [vn for vn in mapping_vn2narration.keys()]
438+
439+
435440
wrong_answers = np.random.choice(answer_list, size = k-1, replace = False)
436441
answer_ids = [gt_vn] + list(wrong_answers)
437442
random.shuffle(answer_ids)

0 commit comments

Comments
 (0)