Skip to content

Commit 9404e45

Browse files
authored
chore: bump version to v2.0.0rc2
1 parent d907922 commit 9404e45

File tree

4 files changed

+9
-9
lines changed

4 files changed

+9
-9
lines changed

edgelab/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '2.0.0rc1'
1+
__version__ = '2.0.0rc2'
22
short_version = __version__
33

44

requirements/mmlab.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
mmcls>=1.0.0.rc6
33
mmcv>=2.0.0
44
mmdet>=3.0.0, <3.1.0 # mmyolo currently does not support mmdet 3.1.0
5-
mmengine>=0.7.2
5+
mmengine>=0.8.2
66
mmpose>=1.0.0
77

88
mmyolo@git+https://github.com/mjq2020/mmyolo

scripts/test_functional.sh

100644100755
File mode changed.

tools/train.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ def build_config(args):
178178

179179
def main():
180180
from mmengine.analysis import get_model_complexity_info
181-
from mmengine.device import get_device
182181

183182
args = parse_args()
184183
args = verify_args(args)
@@ -193,15 +192,16 @@ def main():
193192

194193
runner = RUNNERS.build(cfg)
195194

196-
device = get_device()
197-
dummy_inputs = torch.randn(*args.input_shape, device=device)
198-
model = runner.model.to(device=device)
195+
model = runner.model.to('cpu')
199196
model.eval()
200197

201-
analysis_results = get_model_complexity_info(model=model, inputs=(dummy_inputs,))
198+
analysis_results = get_model_complexity_info(model=model, input_shape=tuple(args.input_shape[1:]))
202199

203-
print('Model Flops:{}'.format(analysis_results['flops_str']))
204-
print('Model Parameters:{}'.format(analysis_results['params_str']))
200+
print('=' * 40)
201+
print(f"{'Input Shape':^20}:{str(args.input_shape):^20}")
202+
print(f"{'Model Flops':^20}:{analysis_results['flops_str']:^20}")
203+
print(f"{'Model Parameters':^20}:{analysis_results['params_str']:^20}")
204+
print('=' * 40)
205205

206206
runner.train()
207207

0 commit comments

Comments
 (0)