Skip to content

Commit 8e6ad23

Browse files
authored
feat: swift yolo mbnv4
* Squashed commit of the following: commit 87250f5 Author: mjq2020 <mjqx2011@163.com> Date: Fri Apr 26 10:33:07 2024 +0000 add: mobilenetv4 backbone commit 0771f18 Merge: 8e0b2f7 7f9c4e0 Author: mjq2020 <mjqx2011@163.com> Date: Fri Apr 26 10:31:57 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit 8e0b2f7 Merge: ac0f39d 9b00e64 Author: mjq2020 <mjqx2011@163.com> Date: Fri Apr 19 06:22:49 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit ac0f39d Merge: c4ea712 1f67493 Author: mjq2020 <mjqx2011@163.com> Date: Mon Apr 1 10:02:11 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit c4ea712 Author: mjq2020 <mjqx2011@163.com> Date: Mon Apr 1 10:00:39 2024 +0000 Fix: cls loss weight too high commit b87fc04 Merge: f146c73 ee72f81 Author: mjq2020 <mjqx2011@163.com> Date: Mon Apr 1 09:54:45 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit f146c73 Merge: c068454 289360c Author: mjq2020 <mjqx2011@163.com> Date: Tue Mar 19 02:16:47 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit c068454 Author: mjq2020 <mjqx2011@163.com> Date: Mon Mar 18 07:05:34 2024 +0000 Optim: model inference display commit fc8874f Author: mjq2020 <mjqx2011@163.com> Date: Mon Mar 18 07:03:33 2024 +0000 Fix: data type bug commit f1d76fc Merge: 1378e1b 3c61e3e Author: mjq2020 <74635395+mjq2020@users.noreply.github.com> Date: Thu Mar 14 18:42:18 2024 +0800 Merge branch 'Seeed-Studio:main' into main commit 1378e1b Merge: 8c8ffd7 31c5291 Author: mjq2020 <74635395+mjq2020@users.noreply.github.com> Date: Tue Jan 30 11:41:16 2024 +0800 Merge branch 'Seeed-Studio:main' into main commit 8c8ffd7 Merge: c67ed2d ebb1ec2 Author: mjq2020 <74635395+mjq2020@users.noreply.github.com> Date: Fri Oct 13 11:09:55 2023 +0800 Merge branch 'Seeed-Studio:main' into main commit c67ed2d Merge: d70e424 9be0612 Author: mjq2020 <74635395+mjq2020@users.noreply.github.com> Date: Sat Sep 23 16:18:57 2023 +0800 Merge branch 'Seeed-Studio:main' into main * feat: mobilenetv4 swift yolo backbone * chore: modify mbnv4 medium/large outputs
1 parent 343a2cb commit 8e6ad23

File tree

3 files changed

+200
-18
lines changed

3 files changed

+200
-18
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved.
2+
_base_ = ['./base_arch.py']
3+
4+
# ========================Suggested optional parameters========================
5+
# MODEL
6+
num_classes = 71
7+
deepen_factor = 0.33
8+
widen_factor = 1
9+
10+
# DATA
11+
dataset_type = 'sscma.CustomYOLOv5CocoDataset'
12+
train_ann = 'train/_annotations.coco.json'
13+
train_data = 'train/' # Prefix of train image path
14+
val_ann = 'valid/_annotations.coco.json'
15+
val_data = 'valid/' # Prefix of val image path
16+
17+
# dataset link: https://universe.roboflow.com/team-roboflow/coco-128
18+
data_root = 'https://universe.roboflow.com/ds/z5UOcgxZzD?key=bwx9LQUT0t'
19+
height = 192
20+
width = 192
21+
batch = 16
22+
workers = 2
23+
val_batch = batch
24+
val_workers = workers
25+
imgsz = (width, height)
26+
27+
# TRAIN
28+
persistent_workers = True
29+
30+
# ================================END=================================
31+
32+
# DATA
33+
affine_scale = 0.5
34+
# MODEL
35+
strides = [8, 16, 32]
36+
37+
anchors = [
38+
[(10, 13), (16, 30), (33, 23)], # P3/8
39+
[(30, 61), (62, 45), (59, 119)], # P4/16
40+
[(116, 90), (156, 198), (373, 326)], # P5/32
41+
]
42+
43+
# default_scope = 'sscma'
44+
45+
model = dict(
46+
type='mmyolo.YOLODetector',
47+
backbone=dict(
48+
_delete_=True,
49+
type='sscma.MobileNetv4',
50+
arch='small'
51+
),
52+
neck=dict(
53+
type='mmyolo.YOLOv5PAFPN',
54+
deepen_factor=deepen_factor,
55+
widen_factor=widen_factor,
56+
in_channels=[64, 96, 128],
57+
out_channels=[64, 96, 128]
58+
),
59+
bbox_head=dict(
60+
head_module=dict(
61+
num_classes=num_classes,
62+
in_channels=[64, 96, 128],
63+
widen_factor=widen_factor,
64+
),
65+
),
66+
)
67+
68+
# ======================datasets==================
69+
70+
71+
batch_shapes_cfg = dict(
72+
type='BatchShapePolicy',
73+
batch_size=1,
74+
img_size=imgsz[0],
75+
# The image scale of padding should be divided by pad_size_divisor
76+
size_divisor=32,
77+
# Additional paddings for pixel scale
78+
extra_pad_ratio=0.5,
79+
)
80+
81+
albu_train_transforms = [
82+
dict(type='Blur', p=0.01),
83+
dict(type='MedianBlur', p=0.01),
84+
dict(type='ToGray', p=0.01),
85+
dict(type='CLAHE', p=0.01),
86+
]
87+
88+
pre_transform = [
89+
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
90+
dict(type='LoadAnnotations', with_bbox=True, _scope_='sscma'),
91+
]
92+
93+
train_pipeline = [
94+
*pre_transform,
95+
dict(type='Mosaic', img_scale=imgsz, pad_val=114.0, pre_transform=pre_transform, _scope_='sscma'),
96+
dict(
97+
type='YOLOv5RandomAffine',
98+
max_rotate_degree=0.0,
99+
max_shear_degree=0.0,
100+
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
101+
# imgsz is (width, height)
102+
border=(-imgsz[0] // 2, -imgsz[1] // 2),
103+
border_val=(114, 114, 114),
104+
_scope_='sscma'
105+
),
106+
dict(
107+
type='mmdet.Albu',
108+
transforms=albu_train_transforms,
109+
bbox_params=dict(type='BboxParams', format='pascal_voc', label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
110+
keymap={'img': 'image', 'gt_bboxes': 'bboxes'},
111+
),
112+
dict(type='YOLOv5HSVRandomAug', _scope_='sscma'),
113+
dict(type='mmdet.RandomFlip', prob=0.5),
114+
dict(
115+
type='mmdet.PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip', 'flip_direction')
116+
),
117+
]
118+
119+
train_dataloader = dict(
120+
batch_size=batch,
121+
num_workers=workers,
122+
persistent_workers=persistent_workers,
123+
pin_memory=True,
124+
sampler=dict(type='DefaultSampler', shuffle=True),
125+
dataset=dict(
126+
type=dataset_type,
127+
data_root=data_root,
128+
ann_file=train_ann,
129+
data_prefix=dict(img=train_data),
130+
filter_cfg=dict(filter_empty_gt=False, min_size=32),
131+
pipeline=train_pipeline,
132+
),
133+
)
134+
135+
test_pipeline = [
136+
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
137+
dict(type='YOLOv5KeepRatioResize', scale=imgsz, _scope_='sscma'),
138+
dict(type='sscma.LetterResize', scale=imgsz, allow_scale_up=False, pad_val=dict(img=114), _scope_='sscma'),
139+
dict(type='LoadAnnotations', with_bbox=True, _scope_='sscma'),
140+
dict(
141+
type='mmdet.PackDetInputs',
142+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', 'pad_param'),
143+
),
144+
]
145+
146+
val_dataloader = dict(
147+
batch_size=val_batch,
148+
num_workers=val_workers,
149+
persistent_workers=persistent_workers,
150+
pin_memory=True,
151+
drop_last=False,
152+
sampler=dict(type='DefaultSampler', shuffle=False),
153+
dataset=dict(
154+
type=dataset_type,
155+
data_root=data_root,
156+
test_mode=True,
157+
data_prefix=dict(img=val_data),
158+
ann_file=val_ann,
159+
pipeline=test_pipeline,
160+
batch_shapes_cfg=batch_shapes_cfg,
161+
),
162+
)
163+
164+
test_dataloader = val_dataloader

sscma/models/backbones/MobileNetv4.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from torch import Tensor
66

77
from sscma.models.base.general import ConvNormActivation
8-
from sscma.registry import BACKBONES
8+
from sscma.registry import MODELS
9+
910
from sscma.models.layers.nn_blocks import (
1011
UniversalInvertedBottleneckBlock,
1112
InvertedBottleneckBlock,
@@ -129,7 +130,8 @@ def mhsa_large_12px():
129130
)
130131

131132

132-
@BACKBONES.register_module()
133+
134+
@MODELS.register_module()
133135
class MobileNetv4(nn.Module):
134136
'''
135137
Architecture: https://arxiv.org/abs/2404.10518
@@ -144,7 +146,7 @@ class MobileNetv4(nn.Module):
144146
'small': [
145147
('convbn', 'ReLU', 3, None, None, False, 2, 32, None, False), # 1/2
146148
('fused_ib', 'ReLU', 3, None, None, False, 2, 32, 1, False), # 1/4
147-
('fused_ib', 'ReLU', 3, None, None, False, 2, 64, 3, False), # 1/8
149+
('fused_ib', 'ReLU', 3, None, None, False, 2, 64, 3, True), # 1/8
148150
('uib', 'ReLU', None, 5, 5, True, 2, 96, 3.0, False), # 1/16
149151
('uib', 'ReLU', None, 0, 3, True, 1, 96, 2.0, False), # IB
150152
('uib', 'ReLU', None, 0, 3, True, 1, 96, 2.0, False), # IB
@@ -193,7 +195,7 @@ class MobileNetv4(nn.Module):
193195
],
194196
'large': [
195197
('convbn', 'ReLU', 3, None, None, False, 2, 24, None, False),
196-
('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4.0, True),
198+
('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4.0, False),
197199
('uib', 'ReLU', None, 3, 5, True, 2, 96, 4.0, False),
198200
('uib', 'ReLU', None, 3, 3, True, 1, 96, 4.0, True),
199201
('uib', 'ReLU', None, 3, 5, True, 2, 192, 4.0, False),
@@ -227,9 +229,9 @@ class MobileNetv4(nn.Module):
227229
],
228230
'hybridmedium': [
229231
('convbn', 'ReLU', 3, None, None, False, 2, 32, None, False), # 1/2
230-
('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4, True), # 1/4
232+
('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4, False), # 1/4
231233
('uib', 'ReLU', None, 3, 5, True, 2, 80, 4.0, False), # IB
232-
('uib', 'ReLU', None, 3, 3, True, 1, 80, 2.0, False), # IB
234+
('uib', 'ReLU', None, 3, 3, True, 1, 80, 2.0, True), # IB
233235
('uib', 'ReLU', None, 3, 5, True, 2, 160, 6.0, False), # IB
234236
('uib', 'ReLU', None, 0, 0, True, 1, 160, 2.0, False), # IB
235237
('uib', 'ReLU', None, 3, 3, True, 1, 160, 4.0, False), # IB
@@ -242,7 +244,7 @@ class MobileNetv4(nn.Module):
242244
('uib', 'ReLU', None, 3, 3, True, 1, 160, 4.0, False),
243245
mhsa_medium_24px(),
244246
('uib', 'ReLU', None, 3, 0, True, 1, 160, 4.0, True),
245-
('uib', 'ReLU', None, 5, 5, True, 2, 256, 6.0, True),
247+
('uib', 'ReLU', None, 5, 5, True, 2, 256, 6.0, False),
246248
('uib', 'ReLU', None, 5, 5, True, 1, 256, 4.0, False),
247249
('uib', 'ReLU', None, 3, 5, True, 1, 256, 4.0, False),
248250
('uib', 'ReLU', None, 3, 5, True, 1, 256, 4.0, False),
@@ -265,7 +267,7 @@ class MobileNetv4(nn.Module):
265267
],
266268
'hybridlarge': [
267269
('convbn', 'GELU', 3, None, None, False, 2, 24, None, False), # 1/2
268-
('fused_ib', 'GELU', 3, None, None, False, 2, 48, 4, True), # 1/4
270+
('fused_ib', 'GELU', 3, None, None, False, 2, 48, 4, False), # 1/4
269271
('uib', 'GELU', None, 3, 5, True, 2, 96, 4.0, False), # IB
270272
('uib', 'GELU', None, 3, 3, True, 1, 96, 4.0, True), # IB
271273
('uib', 'GELU', None, 3, 5, True, 2, 192, 4.0, False), # IB
@@ -283,7 +285,7 @@ class MobileNetv4(nn.Module):
283285
('uib', 'GELU', None, 5, 3, True, 1, 192, 4.0, False),
284286
mhsa_large_24px(),
285287
('uib', 'GELU', None, 3, 0, True, 1, 192, 4.0, True), # output
286-
('uib', 'GELU', None, 5, 5, True, 2, 512, 4.0, True),
288+
('uib', 'GELU', None, 5, 5, True, 2, 512, 4.0, False),
287289
('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False),
288290
('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False),
289291
('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False),
@@ -326,23 +328,29 @@ def __init__(
326328

327329
self._output_stride: int = (1,)
328330

329-
self.blocks_setting = []
331+
self.block_settings = []
330332
for setting in arch_setting:
331333
if isinstance(setting, tuple):
332-
self.blocks_setting.append(BlockConfig(*setting, input_channels=input_channels))
334+
self.block_settings.append(BlockConfig(*setting, input_channels=input_channels))
333335
else:
334-
self.blocks_setting.append(BlockConfig(**setting, input_channels=input_channels))
335-
if self.blocks_setting[-1].output_channels is not None:
336-
input_channels = self.blocks_setting[-1].output_channels
336+
self.block_settings.append(BlockConfig(**setting, input_channels=input_channels))
337+
if self.block_settings[-1].output_channels is not None:
338+
input_channels = self.block_settings[-1].output_channels
337339

338-
self._forward_blocks = self.build_layers()
340+
last_output_block = 0
341+
for i, block in enumerate(self.block_settings):
342+
if block.isoutputblock:
343+
last_output_block = i
344+
345+
self._forward_blocks = self.build_layers()[: last_output_block + 1]
339346

340347
def build_layers(self):
341348
layers = []
342349
block: BlockConfig
343350
current_stride = 1
344351
rate = 1
345-
for block in self.blocks_setting:
352+
353+
for block in self.block_settings:
346354

347355
if not block.stride:
348356
block.stride = 1
@@ -355,6 +363,7 @@ def build_layers(self):
355363
layer_stride = block.stride
356364
layer_rate = 1
357365
current_stride *= block.stride
366+
358367
if block.block_name == 'convbn':
359368
layer = ConvNormActivation(
360369
block.input_channels,
@@ -422,9 +431,16 @@ def build_layers(self):
422431
)
423432
else:
424433
raise ValueError(f'block name "{block.block_name}" is not supported')
434+
425435
layers.append(layer)
436+
426437
return nn.Sequential(*layers)
427438

428439
def forward(self, x):
429-
x = self._forward_blocks(x)
430-
return x
440+
outs = []
441+
for cfg, blk in zip(self.block_settings, self._forward_blocks):
442+
x = blk(x)
443+
if cfg.isoutputblock:
444+
outs.append(x)
445+
446+
return tuple(outs)

sscma/models/backbones/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .EfficientNet import EfficientNet
44
from .MobileNetv2 import MobileNetv2
55
from .MobileNetv3 import MobileNetV3
6+
from .MobileNetv4 import MobileNetv4
67
from .pfld_mobilenet_v2 import PfldMobileNetV2
78
from .ShuffleNetV2 import ShuffleNetV2, CustomShuffleNetV2, FastShuffleNetV2
89
from .SoundNet import SoundNetRaw
@@ -17,6 +18,7 @@
1718
'CustomShuffleNetV2',
1819
'AxesNet',
1920
'MobileNetV3',
21+
'MobileNetv4',
2022
'ShuffleNetV2',
2123
'SqueezeNet',
2224
'EfficientNet',

0 commit comments

Comments
 (0)