From 8986d9323743918fe660caf85f739db5145898d6 Mon Sep 17 00:00:00 2001 From: fanqiNO1 <1848839264@qq.com> Date: Tue, 16 Apr 2024 09:42:40 +0800 Subject: [PATCH 1/2] [Feature] Add op input_info logger --- mmcv/utils/ext_loader.py | 4 +++ mmcv/utils/op_input_info_logger.py | 43 ++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 mmcv/utils/op_input_info_logger.py diff --git a/mmcv/utils/ext_loader.py b/mmcv/utils/ext_loader.py index a31e107dfe..477987b873 100644 --- a/mmcv/utils/ext_loader.py +++ b/mmcv/utils/ext_loader.py @@ -7,12 +7,16 @@ import torch +from .op_input_info_logger import OpInputInfoLogger + if torch.__version__ != 'parrots': def load_ext(name, funcs): ext = importlib.import_module('mmcv.' + name) for fun in funcs: assert hasattr(ext, fun), f'{fun} miss in module {name}' + if os.getenv('MMCV_OPS_PRINT', '0') == '1': + setattr(ext, fun, OpInputInfoLogger(getattr(ext, fun))) return ext else: from parrots import extension diff --git a/mmcv/utils/op_input_info_logger.py b/mmcv/utils/op_input_info_logger.py new file mode 100644 index 0000000000..0458382c0f --- /dev/null +++ b/mmcv/utils/op_input_info_logger.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from collections import OrderedDict + +import torch + + +class OpInputInfoLogger: + + def __init__(self, op): + self.op = op + self.op_name = self.op.__name__ + print(f'Wrap mmcv.ops.{self.op_name} with OpsInfoLogger') + + def _get_input_info(self, *args, **kwargs): + input_info = OrderedDict() + for i, arg in enumerate(args): + input_info[f'arg_{i}'] = arg + for name, value in kwargs.items(): + input_info[name] = value + return input_info + + def _dump_input_info(self, input_info): + info = dict() + info[f'mmcv.ops.{self.op_name}'] = OrderedDict() + for name, param in input_info.items(): + if isinstance(param, torch.Tensor): + info[f'mmcv.ops.{self.op_name}'][name] = { + 'shape': str(param.shape), + 'dtype': str(param.dtype), + } + else: + info[f'mmcv.ops.{self.op_name}'][name] = { + 'value': str(param), + 'type': str(type(param)), + } + with open('ops_input_info.jsonl', 'a') as f: + f.write(json.dumps(info) + '\n') + + def __call__(self, *args, **kwargs): + input_info = self._get_input_info(*args, **kwargs) + self._dump_input_info(input_info) + return self.op(*args, **kwargs) From 4b1c75c8e6fcbbb2dc94e4efd477b05a222b464b Mon Sep 17 00:00:00 2001 From: fanqiNO1 <1848839264@qq.com> Date: Tue, 23 Apr 2024 17:50:29 +0800 Subject: [PATCH 2/2] [Fix] Skip wrap when wrapped --- mmcv/utils/ext_loader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mmcv/utils/ext_loader.py b/mmcv/utils/ext_loader.py index 477987b873..83ac5ed6fa 100644 --- a/mmcv/utils/ext_loader.py +++ b/mmcv/utils/ext_loader.py @@ -16,6 +16,8 @@ def load_ext(name, funcs): for fun in funcs: assert hasattr(ext, fun), f'{fun} miss in module {name}' if os.getenv('MMCV_OPS_PRINT', '0') == '1': + if isinstance(getattr(ext, fun), OpInputInfoLogger): + continue setattr(ext, fun, OpInputInfoLogger(getattr(ext, fun))) return ext else: