Skip to content

Commit d907922

Browse files
authored
Merge pull request #109 from mjq2020/dev
Optimizing Dataset Sampling
2 parents 325d7f7 + 478c1b5 commit d907922

File tree

4 files changed

+79
-17
lines changed

4 files changed

+79
-17
lines changed

configs/accelerometer/3axes_accelerometer_62.5Hz_1s_classify.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,19 @@
44

55
num_classes = 3
66
num_axes = 3
7-
frequency = 62.5
8-
window = 1000
7+
window_size = 30
8+
stride = 20
99

1010
model = dict(
1111
type='AccelerometerClassifier',
1212
backbone=dict(
1313
type='AxesNet',
1414
num_axes=num_axes,
15-
frequency=frequency,
16-
window=window,
15+
window_size=window_size,
1716
num_classes=num_classes,
1817
),
1918
head=dict(
20-
type='edgelab.ClsHead',
19+
type='edgelab.AxesClsHead',
2120
loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
2221
topk=(1, 5),
2322
),
@@ -29,15 +28,15 @@
2928
batch_size = 1
3029
workers = 1
3130

32-
shape = num_classes * int(62.5 * 1000 / 1000)
31+
shape = [1, num_axes * window_size]
3332

3433
train_pipeline = [
35-
dict(type='edgelab.LoadSensorFromFile'),
34+
# dict(type='edgelab.LoadSensorFromFile'),
3635
dict(type='edgelab.PackSensorInputs'),
3736
]
3837

3938
test_pipeline = [
40-
dict(type='edgelab.LoadSensorFromFile'),
39+
# dict(type='edgelab.LoadSensorFromFile'),
4140
dict(type='edgelab.PackSensorInputs'),
4241
]
4342

@@ -49,6 +48,8 @@
4948
data_root=data_root,
5049
data_prefix='training',
5150
ann_file='info.labels',
51+
window_size=window_size,
52+
stride=stride,
5253
pipeline=train_pipeline,
5354
),
5455
sampler=dict(type='DefaultSampler', shuffle=True),
@@ -61,6 +62,8 @@
6162
dataset=dict(
6263
type=dataset_type,
6364
data_root=data_root,
65+
window_size=window_size,
66+
stride=stride,
6467
data_prefix='testing',
6568
ann_file='info.labels',
6669
pipeline=test_pipeline,

edgelab/datasets/sensordataset.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import json
22
import os
3-
from typing import Optional, Union
3+
import os.path as osp
4+
from typing import List, Optional, Union
45

6+
import cbor
7+
import numpy as np
58
from mmcls.datasets import CustomDataset
69

710
from edgelab.registry import DATASETS
@@ -17,6 +20,11 @@ def __init__(
1720
metainfo: Optional[dict] = None,
1821
data_root: str = '',
1922
data_prefix: Union[str, dict] = '',
23+
window_size: int = 80,
24+
stride: int = 30,
25+
retention: float = 0.8,
26+
# source: str = 'EI',
27+
flatten: bool = True,
2028
multi_label: bool = False,
2129
**kwargs,
2230
):
@@ -26,6 +34,12 @@ def __init__(
2634
self.data_root = data_root
2735
self.ann_file = ann_file
2836
self.data_prefix = data_prefix
37+
self.window_size = window_size
38+
self.stride = stride
39+
self.retention = retention
40+
self.flatten = flatten
41+
42+
self.data_dir = osp.join(self.data_root, self.data_prefix)
2943

3044
self.info_lables = json.load(open(os.path.join(self.data_root, self.data_prefix, self.ann_file)))
3145

@@ -58,7 +72,6 @@ def _find_samples(self):
5872
gt_label = j
5973
break
6074
samples.append((filename, gt_label))
61-
print(samples)
6275
return samples
6376

6477
def load_data_list(self):
@@ -75,12 +88,56 @@ def load_data_list(self):
7588

7689
data_list = []
7790
for filename, gt_label in samples:
78-
img_path = os.path.join(self.img_prefix, filename)
79-
info = {'file_path': img_path, 'gt_label': int(gt_label)}
80-
data_list.append(info)
91+
ann_path = os.path.join(self.data_dir, filename)
92+
data_list.extend(
93+
[{'data': np.asanyarray([data]), 'gt_label': int(gt_label)} for data in self.read_split_data(ann_path)]
94+
)
8195

8296
return data_list
8397

98+
def read_split_data(self, file_path: str) -> List:
    """Load one sensor recording and cut it into fixed-length windows.

    The file is decoded as CBOR (``.cbor``) or JSON (``.json``) and the
    samples are taken from ``data['payload']['values']``. Windows start
    every ``self.stride`` samples and are ``self.window_size`` samples
    long. A recording no longer than one window yields a single padded
    window. A trailing window that runs out of samples is zero-padded via
    ``self.pad_data`` when ``retention * window_size < remaining + 1`` and
    is dropped otherwise.

    Args:
        file_path: Path of the recording; the (case-insensitive) extension
            selects the decoder.

    Returns:
        A list of numpy arrays, one per window. Each window is reshaped to
        1-D (row-major, i.e. sample after sample) when ``self.flatten`` is
        true, otherwise it keeps ``window_size`` rows. An empty recording
        yields an empty list.

    Raises:
        ValueError: If the extension is neither ``.cbor`` nor ``.json``.
    """
    lowered = file_path.lower()
    if lowered.endswith('.cbor'):
        with open(file_path, 'rb') as f:
            data = cbor.loads(f.read())
    elif lowered.endswith('.json'):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    else:
        # Previously fell through and failed later with an unbound `data`.
        raise ValueError(f'Unsupported sensor file type (expected .cbor or .json): {file_path}')

    values = np.asanyarray(data['payload']['values'])
    values_len = len(values)
    if values_len == 0:
        # Nothing to window; padding an empty array has no defined axes count.
        return []

    window_size = self.window_size
    if values_len <= window_size:
        windows = [self.pad_data(values, window_size)]
    else:
        windows = []
        min_tail = self.retention * window_size
        for start in range(0, values_len, self.stride):
            remaining = values_len - start
            if remaining > window_size:
                # Full window. This also covers the last start index when
                # stride > window_size, which used to request negative padding.
                windows.append(values[start:start + window_size])
            elif min_tail < remaining + 1:
                # NOTE(review): the `+ 1` is kept from the original threshold;
                # confirm whether `remaining` alone was intended.
                windows.append(self.pad_data(values[start:], window_size))
            # else: tail too short to be worth keeping, skip it.

    if self.flatten:
        # Honour `flatten` for short recordings as well as windowed ones.
        windows = [window.reshape(-1) for window in windows]
    return windows
133+
134+
def pad_data(self, data: np.ndarray, total_len: int, mode='constant', pad_val=0) -> np.ndarray:
    """Pad ``data`` along its first axis so that it has ``total_len`` rows.

    The missing rows are split between both ends; when the amount is odd
    the extra row goes at the end. All other axes are left untouched.

    Args:
        data: Array-like with at least one dimension.
        total_len: Desired length of the first axis.
        mode: Padding mode forwarded to ``np.pad``.
        pad_val: Fill value, used only when ``mode`` is ``'constant'``.

    Returns:
        A new array whose first axis has length ``total_len``.

    Raises:
        ValueError: If ``data`` already has more than ``total_len`` rows.
    """
    data = np.asanyarray(data)
    pad_len = total_len - len(data)
    if pad_len < 0:
        raise ValueError(f'Cannot pad {len(data)} rows down to total_len={total_len}')
    front = pad_len // 2
    after = pad_len - front
    # Pad only the first axis, whatever the rank of the input.
    pad_width = [(front, after)] + [(0, 0)] * (data.ndim - 1)
    # np.pad rejects `constant_values` for every mode except 'constant'.
    extra = {'constant_values': pad_val} if mode == 'constant' else {}
    return np.pad(data, pad_width, mode=mode, **extra)
140+
84141
def is_valid_file(self, filename: str) -> bool:
    """Report whether ``filename`` is accepted as a sample.

    The name is not inspected: every file is accepted and the result is
    always ``True``.
    """
    return True

edgelab/models/backbones/AxesNet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66

77
@MODELS.register_module()
88
class AxesNet(nn.Module):
9-
def __init__(
10-
self, num_axes=3, frequency=62.5, window=1000, num_classes=-1 # axes number # sample frequency # window size
11-
):
9+
def __init__(self, num_axes=3, window_size=80, num_classes=-1): # axes number # sample frequency # window size
1210
super().__init__()
1311
self.num_classes = num_classes
14-
self.intput_feature = num_axes * int(frequency * window / 1000)
12+
self.intput_feature = num_axes * window_size
1513
liner_feature = self.liner_feature_fit()
1614
self.fc1 = nn.Linear(in_features=self.intput_feature, out_features=liner_feature, bias=True)
1715
self.fc2 = nn.Linear(in_features=liner_feature, out_features=liner_feature, bias=True)
@@ -23,6 +21,7 @@ def liner_feature_fit(self):
2321
return (int(self.intput_feature / 1024) + 1) * 256
2422

2523
def forward(self, x):
24+
x = x[0] if isinstance(x, list) else x
2625
x = F.relu(self.fc1(x))
2726
x = F.relu(self.fc2(x))
2827

requirements/base.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# common
22
albumentations>=1.3.0
3+
4+
# sensor
5+
cbor
36
numpy>=1.23.5
47

58
# vision

0 commit comments

Comments
 (0)