|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import sys |
| 4 | +import math |
| 5 | +import importlib |
| 6 | +import inspect |
| 7 | +from typing import Type |
| 8 | +from dataclasses import dataclass, field |
| 9 | +from collections import defaultdict |
| 10 | + |
| 11 | +import torch |
| 12 | +from torch.fx.passes.shape_prop import ShapeProp |
| 13 | +from graph_net.torch import utils |
| 14 | + |
| 15 | + |
| 16 | +def is_single_model_dir(model_dir): |
| 17 | + return os.path.isfile(f"{model_dir}/graph_net.json") |
| 18 | + |
| 19 | + |
| 20 | +def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]: |
| 21 | + spec = importlib.util.spec_from_file_location("unnamed", file_path) |
| 22 | + unnamed = importlib.util.module_from_spec(spec) |
| 23 | + spec.loader.exec_module(unnamed) |
| 24 | + model_class = getattr(unnamed, class_name, None) |
| 25 | + return model_class |
| 26 | + |
| 27 | + |
| 28 | +def get_argument_types(model_class, func_name): |
| 29 | + arg_types = {} |
| 30 | + for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction): |
| 31 | + if name == func_name: |
| 32 | + for arg_name, arg in inspect.signature(func).parameters.items(): |
| 33 | + if arg_name != "self": |
| 34 | + arg_types[arg_name] = ( |
| 35 | + None if arg.annotation is inspect._empty else arg.annotation |
| 36 | + ) |
| 37 | + return arg_types |
| 38 | + |
| 39 | + |
| 40 | +def get_input_dict(model_path, device): |
| 41 | + inputs_params = utils.load_converted_from_text(f"{model_path}") |
| 42 | + params = inputs_params["weight_info"] |
| 43 | + for tensor_meta in params.values(): |
| 44 | + if hasattr(tensor_meta, "device"): |
| 45 | + tensor_meta.device = device |
| 46 | + return { |
| 47 | + k: utils.replay_tensor(v).to(torch.device(device)) for k, v in params.items() |
| 48 | + } |
| 49 | + |
| 50 | + |
| 51 | +@dataclass |
| 52 | +class OpStat: |
| 53 | + op_name: str |
| 54 | + dtype: set[str] = field(default_factory=set) |
| 55 | + count: int = 0 |
| 56 | + |
| 57 | + |
| 58 | +def collect_op_stats(model, input_dict): |
| 59 | + # Use meta tensors as input to avoid actually running the model |
| 60 | + meta_input_dict = {} |
| 61 | + for name, x in input_dict.items(): |
| 62 | + meta_input_dict[name] = ( |
| 63 | + torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x |
| 64 | + ) |
| 65 | + |
| 66 | + # FX symbolic trace |
| 67 | + traced = torch.fx.symbolic_trace(model) |
| 68 | + # print(traced.graph) |
| 69 | + |
| 70 | + node_outputs = {} |
| 71 | + op_stats = {} |
| 72 | + for node in traced.graph.nodes: |
| 73 | + op_name = None |
| 74 | + dtype = None |
| 75 | + if node.op == "placeholder": |
| 76 | + node_outputs[node.name] = meta_input_dict[node.target] |
| 77 | + op_name = node.op |
| 78 | + dtype = node_outputs[node.name].dtype |
| 79 | + elif node.op in ["call_function", "call_method", "call_module"]: |
| 80 | + node_args = [] |
| 81 | + for arg in node.args: |
| 82 | + node_args.append( |
| 83 | + node_outputs[arg.name] if hasattr(arg, "name") else arg |
| 84 | + ) |
| 85 | + node_kwargs = {} |
| 86 | + for k, v in node.kwargs.items(): |
| 87 | + node_kwargs[k] = node_outputs[v.name] if hasattr(v, "name") else v |
| 88 | + |
| 89 | + if node.op == "call_module": |
| 90 | + # classname of module |
| 91 | + submod = dict(traced.named_modules())[node.target] |
| 92 | + op_name = submod.__class__.__name__ |
| 93 | + try: |
| 94 | + out = submod(*node_args, **node_kwargs) |
| 95 | + node_outputs[node.name] = out |
| 96 | + dtype = out.dtype if isinstance(out, torch.Tensor) else None |
| 97 | + except Exception: |
| 98 | + node_outputs[node.name] = None |
| 99 | + elif node.op in ["call_function", "call_method"]: |
| 100 | + op_name = ( |
| 101 | + node.target.__name__ if node.op == "call_function" else node.target |
| 102 | + ) |
| 103 | + try: |
| 104 | + out = node.target(*node_args, **node_kwargs) |
| 105 | + node_outputs[node.name] = out |
| 106 | + dtype = out.dtype if isinstance(out, torch.Tensor) else None |
| 107 | + except Exception: |
| 108 | + print(f"dtype inference failed: op_name={op_name}") |
| 109 | + node_outputs[node.name] = None |
| 110 | + elif node.op == "output": |
| 111 | + op_name = node.op |
| 112 | + node_args = [] |
| 113 | + for arg in node.args: |
| 114 | + node_args.append( |
| 115 | + node_outputs[arg.name] if hasattr(arg, "name") else arg |
| 116 | + ) |
| 117 | + node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args |
| 118 | + dtype = ( |
| 119 | + node_args[0].dtype if isinstance(node_args[0], torch.Tensor) else None |
| 120 | + ) |
| 121 | + else: |
| 122 | + assert False, f"node.op: {node.op}" |
| 123 | + |
| 124 | + if op_name is not None: |
| 125 | + dtype_str = str(dtype).replace("torch.", "") if dtype is not None else None |
| 126 | + if op_stats.get(op_name, None) is None: |
| 127 | + op_stats[op_name] = OpStat(op_name, {dtype_str}, 1) |
| 128 | + else: |
| 129 | + op_stats[op_name].dtype.add(dtype_str) |
| 130 | + op_stats[op_name].count = op_stats[op_name].count + 1 |
| 131 | + return op_stats |
| 132 | + |
| 133 | + |
| 134 | +def collect_model_stats(model_path, device, log_prompt): |
| 135 | + print(f"Collect information for {model_path}") |
| 136 | + model_class = load_class_from_file( |
| 137 | + os.path.join(model_path, "model.py"), "GraphModule" |
| 138 | + ) |
| 139 | + model = model_class() |
| 140 | + input_dict = get_input_dict(model_path, device) |
| 141 | + |
| 142 | + num_ops = 0 |
| 143 | + num_inputs = 0 |
| 144 | + num_outputs = 0 |
| 145 | + dtypes = set() |
| 146 | + op_stats = collect_op_stats(model, input_dict) |
| 147 | + for op_name, stat in op_stats.items(): |
| 148 | + if op_name == "placeholder": |
| 149 | + num_inputs += stat.count |
| 150 | + elif op_name == "output": |
| 151 | + num_outputs += stat.count |
| 152 | + else: |
| 153 | + num_ops += stat.count |
| 154 | + for v in stat.dtype: |
| 155 | + if v is not None: |
| 156 | + dtypes.add(v) |
| 157 | + |
| 158 | + arg_types = get_argument_types(model_class, "forward") |
| 159 | + num_params = 0 |
| 160 | + param_dtypes = set() |
| 161 | + for name, arg_type in arg_types.items(): |
| 162 | + if arg_type == torch.nn.parameter.Parameter: |
| 163 | + count = math.prod(input_dict[name].shape) |
| 164 | + # print(f"Parameter {name}: {count}") |
| 165 | + num_params += count |
| 166 | + param_dtypes.add(str(input_dict[name].dtype).replace("torch.", "")) |
| 167 | + num_params_in_billion = num_params / 1e9 |
| 168 | + |
| 169 | + dtypes_str = "[" + ",".join(dtypes) + "]" |
| 170 | + param_dtypes_str = "[" + ",".join(param_dtypes) + "]" |
| 171 | + print( |
| 172 | + f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str}", |
| 173 | + file=sys.stderr, |
| 174 | + flush=True, |
| 175 | + ) |
| 176 | + |
| 177 | + |
| 178 | +def main(args): |
| 179 | + if args.model_path is not None: |
| 180 | + assert os.path.isdir(args.model_path) |
| 181 | + assert is_single_model_dir(args.model_path) |
| 182 | + collect_model_stats(args.model_path, args.device, args.log_prompt) |
| 183 | + else: |
| 184 | + graph_net_samples_path = ( |
| 185 | + (graph_net.torch.samples_util.get_default_samples_directory()) |
| 186 | + if args.graph_net_samples_path is None |
| 187 | + else args.graph_net_samples_path |
| 188 | + ) |
| 189 | + for root, dirs, files in os.walk(graph_net_samples_path): |
| 190 | + if is_single_model_dir(root): |
| 191 | + collect_model_stats(root, args.device, args.log_prompt) |
| 192 | + |
| 193 | + |
| 194 | +if __name__ == "__main__": |
| 195 | + parser = argparse.ArgumentParser( |
| 196 | + description="Validate a computation graph sample. return 0 if success" |
| 197 | + ) |
| 198 | + parser.add_argument( |
| 199 | + "--device", |
| 200 | + type=str, |
| 201 | + required=False, |
| 202 | + default="cuda", |
| 203 | + help="Device for testing the compiler (e.g., 'cpu' or 'cuda')", |
| 204 | + ) |
| 205 | + parser.add_argument( |
| 206 | + "--model-path", |
| 207 | + type=str, |
| 208 | + required=False, |
| 209 | + default=None, |
| 210 | + help="Computation graph sample directory. e.g '../../samples/torch/resnet18'", |
| 211 | + ) |
| 212 | + parser.add_argument( |
| 213 | + "--graph-net-samples-path", |
| 214 | + type=str, |
| 215 | + required=False, |
| 216 | + default=None, |
| 217 | + help="GraphNet samples directory. e.g '../../samples'", |
| 218 | + ) |
| 219 | + parser.add_argument( |
| 220 | + "--log-prompt", |
| 221 | + type=str, |
| 222 | + required=False, |
| 223 | + default="graph-net-collect-stats-log", |
| 224 | + help="Log prompt for stats log filtering.", |
| 225 | + ) |
| 226 | + args = parser.parse_args() |
| 227 | + main(args=args) |
0 commit comments