11import json
22import os
3- from typing import Optional , Union
3+ import os .path as osp
4+ from typing import List , Optional , Union
45
6+ import cbor
7+ import numpy as np
58from mmcls .datasets import CustomDataset
69
710from edgelab .registry import DATASETS
@@ -17,6 +20,11 @@ def __init__(
1720 metainfo : Optional [dict ] = None ,
1821 data_root : str = '' ,
1922 data_prefix : Union [str , dict ] = '' ,
23+ window_size : int = 80 ,
24+ stride : int = 30 ,
25+ retention : float = 0.8 ,
26+ # source: str = 'EI',
27+ flatten : bool = True ,
2028 multi_label : bool = False ,
2129 ** kwargs ,
2230 ):
@@ -26,6 +34,12 @@ def __init__(
2634 self .data_root = data_root
2735 self .ann_file = ann_file
2836 self .data_prefix = data_prefix
37+ self .window_size = window_size
38+ self .stride = stride
39+ self .retention = retention
40+ self .flatten = flatten
41+
42+ self .data_dir = osp .join (self .data_root , self .data_prefix )
2943
3044 self .info_lables = json .load (open (os .path .join (self .data_root , self .data_prefix , self .ann_file )))
3145
@@ -58,7 +72,6 @@ def _find_samples(self):
5872 gt_label = j
5973 break
6074 samples .append ((filename , gt_label ))
61- print (samples )
6275 return samples
6376
6477 def load_data_list (self ):
@@ -75,12 +88,56 @@ def load_data_list(self):
7588
7689 data_list = []
7790 for filename , gt_label in samples :
78- img_path = os .path .join (self .img_prefix , filename )
79- info = {'file_path' : img_path , 'gt_label' : int (gt_label )}
80- data_list .append (info )
91+ ann_path = os .path .join (self .data_dir , filename )
92+ data_list .extend (
93+ [{'data' : np .asanyarray ([data ]), 'gt_label' : int (gt_label )} for data in self .read_split_data (ann_path )]
94+ )
8195
8296 return data_list
8397
def read_split_data(self, file_path: str) -> List:
    """Load one sample file and cut its values into fixed-length windows.

    The file is parsed as CBOR or JSON depending on its extension
    (case-insensitive) and the sequence stored under
    ``['payload']['values']`` is windowed along its first axis.

    * A recording no longer than ``self.window_size`` produces a single
      window, padded to ``self.window_size`` via ``self.pad_data``.
    * A longer recording is walked in steps of ``self.stride``. Windows
      that fit entirely are taken as-is. A tail that does not extend
      past the window is padded and kept only if it passes the
      ``self.retention`` test; otherwise it is dropped.
    * When ``self.flatten`` is true every window is reshaped to 1-D.

    Args:
        file_path: Path to a ``.cbor`` or ``.json`` sample file.

    Returns:
        A list of numpy arrays, one per window. Empty when the recording
        holds no values or no window is retained.

    Raises:
        ValueError: If ``file_path`` has neither a ``.cbor`` nor a
            ``.json`` extension.
    """
    lowered = file_path.lower()
    if lowered.endswith('.cbor'):
        with open(file_path, 'rb') as f:
            content = cbor.loads(f.read())
    elif lowered.endswith('.json'):
        with open(file_path, 'r') as f:
            content = json.load(f)
    else:
        # Fail early with a clear message instead of hitting an unbound
        # local further down.
        raise ValueError(f'Unsupported sample file format (expected .cbor or .json): {file_path}')

    # Assumes values is a 2-D (samples, channels) sequence - TODO confirm
    # against the data source.
    values = np.asanyarray(content['payload']['values'])
    values_len = len(values)
    if values_len == 0:
        # Nothing to window; padding an empty recording is not meaningful.
        return []

    window_size = self.window_size
    windows = []
    if values_len <= window_size:
        windows.append(self.pad_data(values, window_size))
    else:
        for start in range(0, values_len, self.stride):
            remaining = values_len - start
            if remaining > window_size:
                # A full window fits strictly inside the recording.
                windows.append(values[start:start + window_size])
            # NOTE(review): the `+ 1` is kept from the original comparison;
            # confirm whether `remaining >= retention * window_size` was
            # the intended rule.
            elif self.retention * window_size < remaining + 1:
                windows.append(self.pad_data(values[start:], window_size))
            # Otherwise the tail is too short to be worth keeping.

    if self.flatten:
        # NOTE(review): the original applied transpose(0, 1) before
        # flattening, which is the identity for 2-D input; if a
        # channel-major layout was intended it should be transpose(1, 0).
        windows = [window.reshape(-1) for window in windows]
    return windows
133+
def pad_data(self, data: np.ndarray, total_len: int, mode: str = 'constant', pad_val=0) -> np.ndarray:
    """Pad ``data`` along its first axis to ``total_len`` entries.

    The padding is split between both ends; when the amount is odd the
    extra entry goes at the end. All other axes are left untouched.

    Args:
        data: Array to pad; its first axis is the one extended.
        total_len: Required length of the first axis after padding.
        mode: Padding mode forwarded to ``numpy.pad``.
        pad_val: Fill value, used only when ``mode`` is ``'constant'``.

    Returns:
        The padded array, with first-axis length ``total_len``.

    Raises:
        ValueError: If ``data`` is already longer than ``total_len``.
    """
    data = np.asanyarray(data)
    pad_len = total_len - len(data)
    if pad_len < 0:
        raise ValueError(f'Cannot pad data of length {len(data)} to a shorter length {total_len}')

    front = pad_len // 2
    after = pad_len - front
    # Extend only the first axis, whatever the rank of the input.
    pad_width = [(front, after)] + [(0, 0)] * (data.ndim - 1)

    if mode == 'constant':
        return np.pad(data, pad_width, mode=mode, constant_values=pad_val)
    # numpy.pad rejects `constant_values` for every non-constant mode.
    return np.pad(data, pad_width, mode=mode)
140+
def is_valid_file(self, filename: str) -> bool:
    """Report whether ``filename`` should be treated as a sample.

    No filtering is applied: every file name is accepted.
    """
    accepted = True
    return accepted
0 commit comments