Skip to content

Commit ce4129e

Browse files
committed
WIP
1 parent 935fe80 commit ce4129e

File tree

1 file changed

+29
-17
lines changed

1 file changed

+29
-17
lines changed

action/chatgpt_utils.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,12 @@ def split_indices(indices, num_chunks):
4848

4949
return chunks
5050

51-
class GPTAnnotator:
52-
def __init__(self, ann_file, data_root, clip_length = 4):
53-
self.ann_file = ann_file
54-
self.data_root = data_root
51+
52+
class ChatGPT:
53+
def __init__(self, clip_length = 4):
5554
self.clip_length = clip_length
56-
data = []
57-
with open(ann_file, 'r') as f:
58-
for line in f:
59-
data.append(json.loads(line))
60-
self.data = data
61-
62-
def prepare_multiple_images(self, images):
55+
56+
def prepare_multi_images(self, images):
6357
"""
6458
6559
"""
@@ -91,13 +85,9 @@ def prepare_multiple_images(self, images):
9185
for encoded_image in encoded_image_list
9286
]
9387

94-
return multi_image_content
95-
88+
return multi_image_content
9689

9790
def extract_frames(self, data_root, vid_path, start_second, end_second):
98-
99-
100-
10191
frames, time_meta = avion_video_loader(data_root,
10292
vid_path,
10393
'MP4',
@@ -109,9 +99,31 @@ def extract_frames(self, data_root, vid_path, start_second, end_second):
10999
fast_rrc=False,
110100
fast_rcc = False,
111101
jitter = False)
112-
return frames, time_meta
102+
return frames, time_meta
113103

114104

105+
class GPTInferenceAnnotator(ChatGPT):
106+
"""
107+
Given the images, this class will annotate the video frames
108+
"""
109+
110+
111+
class GPTAugmentationAnnotator(ChatGPT):
112+
"""
113+
Given the train annotation from the EK100 dataset, this class will annotate the video frames
114+
that augments the gt annotations.
115+
"""
116+
117+
def __init__(self, ann_file, data_root, clip_length = 4):
118+
super().__init__(clip_length = clip_length)
119+
self.ann_file = ann_file
120+
self.data_root = data_root
121+
self.clip_length = clip_length
122+
data = []
123+
with open(ann_file, 'r') as f:
124+
for line in f:
125+
data.append(json.loads(line))
126+
self.data = data
115127

116128
def parse_conversation_from_train_convs(self, item):
117129
"""

0 commit comments

Comments
 (0)