diff --git a/mmcv/utils/ext_loader.py b/mmcv/utils/ext_loader.py index a31e107dfe..83ac5ed6fa 100644 --- a/mmcv/utils/ext_loader.py +++ b/mmcv/utils/ext_loader.py @@ -7,12 +7,18 @@ import torch +from .op_input_info_logger import OpInputInfoLogger + if torch.__version__ != 'parrots': def load_ext(name, funcs): ext = importlib.import_module('mmcv.' + name) for fun in funcs: assert hasattr(ext, fun), f'{fun} miss in module {name}' + if os.getenv('MMCV_OPS_PRINT', '0') == '1': + if isinstance(getattr(ext, fun), OpInputInfoLogger): + continue + setattr(ext, fun, OpInputInfoLogger(getattr(ext, fun))) return ext else: from parrots import extension diff --git a/mmcv/utils/op_input_info_logger.py b/mmcv/utils/op_input_info_logger.py new file mode 100644 index 0000000000..0458382c0f --- /dev/null +++ b/mmcv/utils/op_input_info_logger.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from collections import OrderedDict + +import torch + + +class OpInputInfoLogger: + + def __init__(self, op): + self.op = op + self.op_name = self.op.__name__ + print(f'Wrap mmcv.ops.{self.op_name} with OpsInfoLogger') + + def _get_input_info(self, *args, **kwargs): + input_info = OrderedDict() + for i, arg in enumerate(args): + input_info[f'arg_{i}'] = arg + for name, value in kwargs.items(): + input_info[name] = value + return input_info + + def _dump_input_info(self, input_info): + info = dict() + info[f'mmcv.ops.{self.op_name}'] = OrderedDict() + for name, param in input_info.items(): + if isinstance(param, torch.Tensor): + info[f'mmcv.ops.{self.op_name}'][name] = { + 'shape': str(param.shape), + 'dtype': str(param.dtype), + } + else: + info[f'mmcv.ops.{self.op_name}'][name] = { + 'value': str(param), + 'type': str(type(param)), + } + with open('ops_input_info.jsonl', 'a') as f: + f.write(json.dumps(info) + '\n') + + def __call__(self, *args, **kwargs): + input_info = self._get_input_info(*args, **kwargs) + self._dump_input_info(input_info) + return self.op(*args, **kwargs)