|
| 1 | +from . import utils |
| 2 | +import argparse |
| 3 | +import importlib.util |
| 4 | +import inspect |
| 5 | +import torch |
| 6 | +import logging |
| 7 | +from pathlib import Path |
| 8 | +from typing import Type, Any |
| 9 | +import sys |
| 10 | +import json |
| 11 | +import base64 |
| 12 | +from contextlib import contextmanager |
| 13 | + |
| 14 | + |
| 15 | +def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]: |
| 16 | + spec = importlib.util.spec_from_file_location("unnamed", file_path) |
| 17 | + unnamed = importlib.util.module_from_spec(spec) |
| 18 | + spec.loader.exec_module(unnamed) |
| 19 | + model_class = getattr(unnamed, class_name, None) |
| 20 | + return model_class |
| 21 | + |
| 22 | + |
| 23 | +def _convert_to_dict(config_str): |
| 24 | + if config_str is None: |
| 25 | + return {} |
| 26 | + config_str = base64.b64decode(config_str).decode("utf-8") |
| 27 | + config = json.loads(config_str) |
| 28 | + assert isinstance(config, dict), f"config should be a dict. {config_str=}" |
| 29 | + return config |
| 30 | + |
| 31 | + |
def _get_decorator(args):
    """Build the model decorator requested via ``--decorator-config``.

    The config is a base64-encoded JSON object; ``decorator_path`` points at a
    file defining ``RunModelDecorator``. When no decorator is configured, an
    identity function is returned so callers can apply it unconditionally.
    """
    def identity(model):
        return model

    if args.decorator_config is None:
        return identity

    config = _convert_to_dict(args.decorator_config)
    if "decorator_path" not in config:
        return identity

    decorator_cls = load_class_from_file(
        config["decorator_path"], class_name="RunModelDecorator"
    )
    # The decorator class is instantiated with its own (optional) sub-config.
    return decorator_cls(config.get("decorator_config", {}))
| 42 | + |
| 43 | + |
def main(args):
    """Load a serialized ``GraphModule`` from ``args.model_path`` and run it once.

    Expects ``<model_path>/model.py`` to define a ``GraphModule`` class and the
    folder to contain the text-serialized weights that ``utils`` can replay.
    The (optionally decorated) model is invoked with the replayed tensors as
    keyword arguments.
    """
    model_path = args.model_path
    graph_module_cls = load_class_from_file(
        f"{model_path}/model.py", class_name="GraphModule"
    )
    assert graph_module_cls is not None
    model = graph_module_cls()
    print(f"{model_path=}")

    # Optionally wrap the model with the user-supplied decorator.
    decorate = _get_decorator(args)
    model = decorate(model)

    # Rebuild the recorded tensors and feed them to the model by name.
    inputs_params = utils.load_converted_from_text(f"{model_path}")
    weight_info = inputs_params["weight_info"]
    replayed = {name: utils.replay_tensor(info) for name, info in weight_info.items()}

    model(**replayed)
| 60 | + |
| 61 | + |
if __name__ == "__main__":
    # CLI entry point: point at a converted-model folder and optionally supply
    # a base64-encoded decorator configuration.
    cli = argparse.ArgumentParser(description="load and run model")
    cli.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to folder e.g '../../samples/torch/resnet18'",
    )
    cli.add_argument(
        "--decorator-config",
        type=str,
        default=None,
        required=False,
        help="decorator configuration string",
    )
    main(args=cli.parse_args())
0 commit comments