@@ -46,39 +46,46 @@ def load_data_list(self) -> List[dict]:
4646 data_list = super ().load_data_list ()
4747
4848 # split text to several slices because of over-length
49- input_ids , bboxes , labels = [], [], []
50- segment_ids , position_ids = [], []
51- image_path = []
49+ split_text_data_list = []
5250 for i in range (len (data_list )):
5351 start = 0
5452 cur_iter = 0
5553 while start < len (data_list [i ]['input_ids' ]):
5654 end = min (start + 510 , len (data_list [i ]['input_ids' ]))
57-
58- input_ids .append ([self .tokenizer .cls_token_id ] +
59- data_list [i ]['input_ids' ][start :end ] +
60- [self .tokenizer .sep_token_id ])
61- bboxes .append ([[0 , 0 , 0 , 0 ]] +
62- data_list [i ]['bboxes' ][start :end ] +
63- [[1000 , 1000 , 1000 , 1000 ]])
64- labels .append ([- 100 ] + data_list [i ]['labels' ][start :end ] +
65- [- 100 ])
66-
67- cur_segment_ids = self .get_segment_ids (bboxes [- 1 ])
68- cur_position_ids = self .get_position_ids (cur_segment_ids )
69- segment_ids .append (cur_segment_ids )
70- position_ids .append (cur_position_ids )
71- image_path .append (
72- os .path .join (self .data_root , data_list [i ]['img_path' ]))
55+ # get input_ids
56+ input_ids = [self .tokenizer .cls_token_id ] + \
57+ data_list [i ]['input_ids' ][start :end ] + \
58+ [self .tokenizer .sep_token_id ]
59+ # get bboxes
60+ bboxes = [[0 , 0 , 0 , 0 ]] + \
61+ data_list [i ]['bboxes' ][start :end ] + \
62+ [[1000 , 1000 , 1000 , 1000 ]]
63+ # get labels
64+ labels = [- 100 ] + data_list [i ]['labels' ][start :end ] + [- 100 ]
65+ # get segment_ids
66+ segment_ids = self .get_segment_ids (bboxes )
67+ # get position_ids
68+ position_ids = self .get_position_ids (segment_ids )
69+ # get img_path
70+ img_path = os .path .join (self .data_root ,
71+ data_list [i ]['img_path' ])
72+ # get attention_mask
73+ attention_mask = [1 ] * len (input_ids )
74+
75+ data_info = {}
76+ data_info ['input_ids' ] = input_ids
77+ data_info ['bboxes' ] = bboxes
78+ data_info ['labels' ] = labels
79+ data_info ['segment_ids' ] = segment_ids
80+ data_info ['position_ids' ] = position_ids
81+ data_info ['img_path' ] = img_path
82+ data_info ['attention_mask' ] = attention_mask
83+ split_text_data_list .append (data_info )
7384
7485 start = end
7586 cur_iter += 1
7687
77- assert len (input_ids ) == len (bboxes ) == len (labels ) == len (
78- segment_ids ) == len (position_ids )
79- assert len (segment_ids ) == len (image_path )
80-
81- return data_list
88+ return split_text_data_list
8289
8390 def parse_data_info (self , raw_data_info : dict ) -> Union [dict , List [dict ]]:
8491 instances = raw_data_info ['instances' ]
0 commit comments