diff --git a/graph_net/collect_stats_util.py b/graph_net/collect_stats_util.py
new file mode 100644
index 000000000..fe98579d8
--- /dev/null
+++ b/graph_net/collect_stats_util.py
@@ -0,0 +1,100 @@
+import ast
+import json
+import importlib.util
+import inspect
+from dataclasses import dataclass, field
+from typing import Dict
+
+
+@dataclass
+class OpStat:
+    op_name: str
+    op_dtypes: dict[str, int] = field(default_factory=dict)
+    count: int = 0
+
+    def update(self, other):
+        if isinstance(other, OpStat) and self.op_name == other.op_name:
+            self.count += other.count
+            for name, count in other.op_dtypes.items():
+                self.op_dtypes[name] = self.op_dtypes.get(name, 0) + count
+
+
+@dataclass
+class ModelStats:
+    model_path: str
+    num_inputs: int = None
+    num_params: int = None
+    num_outputs: int = None
+    num_ops: int = None
+    model_size_in_billion: float = None
+    input_dtypes: Dict[str, int] = field(default_factory=dict)
+    param_dtypes: Dict[str, int] = field(default_factory=dict)
+    input_shapes: Dict[str, list] = field(default_factory=dict)
+    op_dtypes: Dict[str, int] = field(default_factory=dict)
+    ops: Dict[str, int] = field(default_factory=dict)
+    source: str = None
+    heuristic_tag: str = None
+
+
+def print_model_stats(stats, log_prompt):
+    assert isinstance(stats, ModelStats), f"{type(stats)=}"
+
+    def print_with_log_prompt(key, value):
+        print(
+            f"{log_prompt} [ModelStats.{key}] model_path:{stats.model_path} {value}",
+            flush=True,
+        )
+
+    print_with_log_prompt("num_inputs", stats.num_inputs)
+    print_with_log_prompt("num_params", stats.num_params)
+    print_with_log_prompt("num_outputs", stats.num_outputs)
+    print_with_log_prompt("num_ops", stats.num_ops)
+    print_with_log_prompt("model_size", f"{stats.model_size_in_billion}B")
+    print_with_log_prompt("input_dtypes", json.dumps(stats.input_dtypes))
+    print_with_log_prompt("param_dtypes", json.dumps(stats.param_dtypes))
+    print_with_log_prompt("input_shapes", json.dumps(stats.input_shapes))
+    print_with_log_prompt("op_dtypes", json.dumps(stats.op_dtypes))
+    print_with_log_prompt("ops", json.dumps(stats.ops))
+    print_with_log_prompt("source", stats.source)
+    print_with_log_prompt("heuristic_tag", stats.heuristic_tag)
+
+
+def load_class_from_file(file_path, class_name):
+    spec = importlib.util.spec_from_file_location("unnamed", file_path)
+    unnamed = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(unnamed)
+    model_class = getattr(unnamed, class_name, None)
+    return model_class
+
+
+def get_argument_name_and_types(model_class, func_name):
+    argument_name2types = {}
+    for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction):
+        if name == func_name:
+            for arg_name, arg in inspect.signature(func).parameters.items():
+                if arg_name != "self":
+                    argument_name2types[arg_name] = (
+                        None if arg.annotation is inspect._empty else arg.annotation
+                    )
+    return argument_name2types
+
+
+def get_number_of_returns(file_path, class_name, func_name):
+    source = None
+    with open(f"{file_path}", "r") as f:
+        source = f.read()
+
+    tree = ast.parse(source)
+    for node in tree.body:
+        if isinstance(node, ast.ClassDef) and node.name == class_name:
+            for f in node.body:
+                if isinstance(f, ast.FunctionDef) and f.name == func_name:
+                    for stmt in ast.walk(f):
+                        if isinstance(stmt, ast.Return):
+                            if stmt.value is None:
+                                return 0
+                            elif isinstance(stmt.value, ast.Tuple):
+                                return len(stmt.value.elts)
+                            else:
+                                return 1
+    return 0
diff --git a/graph_net/paddle/collect_stats.py b/graph_net/paddle/collect_stats.py
new file mode 100644
index 000000000..34f9c366d
--- /dev/null
+++ b/graph_net/paddle/collect_stats.py
@@ -0,0 +1,303 @@
+import argparse
+import os
+import re
+import sys
+import math
+import json
+import subprocess
+from datetime import datetime
+
+import paddle
+import graph_net.paddle.samples_util
+from graph_net import collect_stats_util
+from graph_net.paddle import utils
+
+
+def is_single_model_dir(model_dir):
+    return os.path.isfile(f"{model_dir}/graph_net.json")
+
+
+def read_graph_source_and_tag(model_path):
+    try:
+        with open(os.path.join(model_path, "graph_net.json"), "r") as f:
+            data = json.load(f)
+        return data["source"], data["heuristic_tag"]
+    except Exception:
+        if "PaddleX" in model_path:
+            return "PaddleX", "computer_vision"
+        elif "PaddleNLP" in model_path:
+            return "PaddleNLP", "nlp"
+        elif "PaddleScience" in model_path:
+            return "PaddleScience", "scientific_computing"
+        else:
+            return "unknown", "unknown"
+
+
+def get_input_spec(model_path):
+    inputs_params_list = utils.load_converted_list_from_text(f"{model_path}")
+    input_spec = [None] * len(inputs_params_list)
+    for i, v in enumerate(inputs_params_list):
+        dtype = v["info"]["dtype"]
+        shape = v["info"]["shape"]
+        input_spec[i] = paddle.static.InputSpec(shape, dtype)
+    return input_spec
+
+
+class ProgramAnalyzer:
+    def __init__(self):
+        self.op_stats = {}
+        self.input_dict = {}
+        self.num_ops = 0
+        self.num_ops_misses_dtypes = 0
+        self.is_complete = True
+
+    def update_op_stats(self, op_name, op_dtype):
+        if op_name is not None:
+            dtype_str = str(op_dtype).replace("paddle.", "")
+            if self.op_stats.get(op_name, None) is None:
+                self.op_stats[op_name] = collect_stats_util.OpStat(
+                    op_name, {dtype_str: 1}, 1
+                )
+            else:
+                self.op_stats[op_name].op_dtypes[dtype_str] = (
+                    self.op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
+                )
+                self.op_stats[op_name].count += 1
+
+    def parse_pir_value_dtypes(self, type_str):
+        short_form2dtype = {
+            "f32": "float32",
+            "f16": "float16",
+            "bf16": "bfloat16",
+            "i64": "int64",
+        }
+        # type_str: "vec[tensor<1x18x13x9xf32>,tensor<1x9x13x9xf32>]"
+        matches = re.findall(r"tensor<([^>]+)>", type_str)
+        dtype_strs = []
+        for s in matches:
+            parts = s.split("x")
+            assert len(parts) > 0
+
+            dtype = parts[-1].lower()
+            dtype_strs.append(short_form2dtype[dtype])
+        return dtype_strs
+
+    def __call__(self, program):
+        assert isinstance(program, paddle.base.libpaddle.pir.Program)
+
+        self.op_stats = {}
+        self.num_ops_misses_dtypes = 0
+        self.num_ops = 0
+        for block in program.blocks:
+            for op in block.ops:
+                op_name = None
+                op_dtype = None
+                if op.name() == "pd_op.data":
+                    op_name = "data"
+                    op_attrs = op.attrs()
+                    op_dtype = op_attrs["dtype"]
+                    self.input_dict[op_attrs["name"]] = {
+                        "dtype": str(op_dtype).replace("paddle.", ""),
+                        "shape": op_attrs["shape"],
+                    }
+                elif op.name().startswith("pd_op."):
+                    self.num_ops += 1
+                    op_name = op.name().replace("pd_op.", "")
+                    try:
+                        if len(op.results()) > 0:
+                            out = op.results()[0]
+                            if out.is_dense_tensor_type():
+                                op_dtype = out.dtype
+                            else:
+                                # paddle.base.libpaddle.pir.VectorType: the dtype cannot be read directly, so parse it from the type string
+                                if op_name in [
+                                    "broadcast_tensors",
+                                    "distribute_fpn_proposals",
+                                    "meshgrid",
+                                    "split",
+                                    "split_with_num",
+                                ]:
+                                    op_dtype = self.parse_pir_value_dtypes(
+                                        str(out.type())
+                                    )[0]
+                                else:
+                                    assert False, f"Unsupported op: {op}"
+                    except Exception:
+                        if self.num_ops_misses_dtypes == 0:
+                            print(f"dtype inference failed for {op_name}")
+                    if op_dtype is not None:
+                        self.update_op_stats(op_name, op_dtype)
+                    else:
+                        self.num_ops_misses_dtypes += 1
+                elif not op.name().startswith("builtin."):
+                    assert False, f"Unrecognized op: {op}"
+
+        if self.num_ops_misses_dtypes > 0:
+            self.is_complete = False
+
+    def summary(self):
+        print(
+            f"Collected {self.num_ops} operators in total; dtype inference failed for {self.num_ops_misses_dtypes} of them."
+        )
+
+
+def collect_op_stats(model, model_path):
+    assert isinstance(model, paddle.nn.Layer), f"{type(model)=}"
+    try:
+        static_model = paddle.jit.to_static(
+            model,
+            input_spec=get_input_spec(model_path),
+            full_graph=True,
+            backend=None,
+        )
+        static_model.eval()
+        program = static_model.forward.concrete_program.main_program
+
+        program_analyzer = ProgramAnalyzer()
+        program_analyzer(program)
+        program_analyzer.summary()
+        return program_analyzer
+    except Exception:
+        print("Failed with to_static")
+        return None
+
+
+def collect_model_stats(model_path, log_prompt):
+    file_path = os.path.join(model_path, "model.py")
+    model_class = collect_stats_util.load_class_from_file(file_path, "GraphModule")
+    model = model_class()
+
+    model_size = 0
+    input_shapes = set()
+    input_dtypes = {}
+    param_dtypes = {}
+    ops_count_dict = {}
+    op_dtypes = {}
+
+    program_analyzer = collect_op_stats(model, model_path)
+    if program_analyzer is not None:
+        for op_name, stat in sorted(program_analyzer.op_stats.items()):
+            ops_count_dict[op_name] = stat.count
+            for dtype_str, num in stat.op_dtypes.items():
+                if dtype_str is not None and dtype_str != "None":
+                    op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num
+
+        inputs_params = utils.load_converted_from_text(f"{model_path}")
+        params = inputs_params["weight_info"]
+        inputs = inputs_params["input_info"]
+
+        for name, value in program_analyzer.input_dict.items():
+            dtype_str = value["dtype"]
+            if name in params.keys():
+                param_numel = math.prod(value["shape"])
+                model_size += param_numel
+                param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
+            elif name in inputs.keys():
+                input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
+                input_shapes.add(str(value["shape"]))
+
+    num_outputs = collect_stats_util.get_number_of_returns(
+        file_path, "GraphModule", "forward"
+    )
+    num_ops = program_analyzer.num_ops if program_analyzer is not None else 0
+    source, heuristic_tag = read_graph_source_and_tag(model_path)
+    is_complete = (
+        program_analyzer.is_complete if program_analyzer is not None else False
+    )
+    print(
+        f"model_stats collection information: model_path={model_path} method=to_static is_ops_complete={is_complete}"
+    )
+
+    stats = collect_stats_util.ModelStats(
+        model_path=model_path,
+        num_inputs=sum(input_dtypes.values()),
+        num_params=sum(param_dtypes.values()),
+        num_outputs=num_outputs,
+        num_ops=num_ops,
+        model_size_in_billion=model_size / 1e9,
+        input_dtypes=input_dtypes,
+        param_dtypes=param_dtypes,
+        input_shapes=list(input_shapes),
+        op_dtypes=op_dtypes,
+        ops=ops_count_dict,
+        source=source,
+        heuristic_tag=heuristic_tag,
+    )
+    collect_stats_util.print_model_stats(stats, log_prompt)
+
+
+def main(args):
+    if args.model_path is not None:
+        assert os.path.isdir(args.model_path)
+        assert is_single_model_dir(args.model_path)
+        timestamp_sec = datetime.now().timestamp()
+        dt = datetime.fromtimestamp(timestamp_sec)
+        formatted_dt = dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+        print(f"[{formatted_dt}] Collect information for {args.model_path}")
+        collect_model_stats(args.model_path, args.log_prompt)
+    else:
+        graph_net_samples_path = (
+            (graph_net.paddle.samples_util.get_default_samples_directory())
+            if args.graph_net_samples_path is None
+            else args.graph_net_samples_path
+        )
+
+        i = 0
+        for root, dirs, files in os.walk(graph_net_samples_path):
+            if is_single_model_dir(root):
+                print(f"[{i}] Collect information for {root}")
+                cmd = [
+                    "python",
+                    "-m",
+                    "graph_net.paddle.collect_stats",
+                    f"--device={args.device}",
+                    f"--model-path={root}",
+                    f"--log-prompt={args.log_prompt}",
+                ]
+                result = subprocess.run(
+                    cmd,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                    text=True,
+                    timeout=600,
+                )
+                print(result.stdout)
+                if result.returncode != 0:
+                    print(result.stderr)
+                i += 1
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Collect stats for computation graph samples. Returns 0 on success."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=False,
+        default="cuda",
+        help="Device to use (e.g., 'cpu' or 'cuda')",
+    )
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        required=False,
+        default=None,
+        help="Computation graph sample directory, e.g. '../../paddle_samples/PaddleX/ResNet18'",
+    )
+    parser.add_argument(
+        "--graph-net-samples-path",
+        type=str,
+        required=False,
+        default=None,
+        help="GraphNet samples directory, e.g. '../../paddle_samples'",
+    )
+    parser.add_argument(
+        "--log-prompt",
+        type=str,
+        required=False,
+        default="graph-net-collect-stats-log",
+        help="Log prompt for stats log filtering.",
+    )
+    args = parser.parse_args()
+    main(args=args)
diff --git a/graph_net/paddle/validate.py b/graph_net/paddle/validate.py
index 9570d6cc2..bd9c9e377 100644
--- a/graph_net/paddle/validate.py
+++ b/graph_net/paddle/validate.py
@@ -36,8 +36,6 @@ def _extract_forward_source(model_path, class_name):
         source = f.read()
 
     tree = ast.parse(source)
-    forward_code = None
-
     for node in tree.body:
         if isinstance(node, ast.ClassDef) and node.name == class_name:
             for fn in node.body:
diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py
new file mode 100644
index 000000000..78a6502a0
--- /dev/null
+++ b/graph_net/torch/collect_stats.py
@@ -0,0 +1,461 @@
+import argparse
+import os
+import sys
+import math
+import json
+import subprocess
+from datetime import datetime
+
+import torch
+import graph_net.torch.samples_util
+from graph_net import collect_stats_util
+from graph_net.torch import utils
+
+
+def is_single_model_dir(model_dir):
+    return os.path.isfile(f"{model_dir}/graph_net.json")
+
+
+def read_graph_source_and_tag(model_path):
+    try:
+        with open(os.path.join(model_path, "graph_net.json"), "r") as f:
+            data = json.load(f)
+        return data["source"], data["heuristic_tag"]
+    except Exception:
+        return "unknown", "unknown"
+
+
+def get_input_dict(model_path, device):
+    inputs_params = utils.load_converted_from_text(f"{model_path}")
+    params = inputs_params["weight_info"]
+    for tensor_meta in params.values():
+        if hasattr(tensor_meta, "device"):
+            tensor_meta.device = device
+    return {
+        k: utils.replay_tensor(v).to(torch.device(device)) for k, v in params.items()
+    }
+
+
+def resolve_native_multi_head_attention(*args, **kwargs):
+    query, key, value = args[0], args[1], args[2]
+    seq_len, batch_size, embed_dim = query.shape
+    attn_output = torch.empty(
+        (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
+    )
+
+    # TODO(Xreki): get value from args
+    need_weights = False
+    if need_weights:
+        seq_len_k = key.shape[0]
+        num_heads = args[4]
+        attn_output_weights = torch.empty(
+            (batch_size, num_heads, seq_len, seq_len_k),
+            dtype=query.dtype,
+            device="meta",
+        )
+        return attn_output, attn_output_weights
+    else:
+        return attn_output
+
+
+def resolve_tensor_to(tensor, *args, **kwargs):
+    if len(args) > 0 and isinstance(args[0], torch.dtype):
+        dtype = args[0]
+    else:
+        dtype = tensor.dtype
+    return torch.empty(tensor.shape, dtype=dtype, device="meta")
+
+
+def resolve_tensor_item(tensor):
+    return torch.empty((), dtype=tensor.dtype, device="meta")
+
+
+def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
+    attr_itr = node.target.split(".")
+    val = gm
+    for a in attr_itr:
+        val = getattr(val, a)
+    out = val.to(device="meta") if isinstance(val, torch.Tensor) else val
+    return out
+
+
+def convert_real_to_meta(x):
+    if isinstance(x, torch.Tensor) and not x.is_meta:
+        return torch.empty_like(x, device="meta")
+    elif isinstance(x, (list, tuple)):
+        return type(x)(convert_real_to_meta(v) for v in x)
+    elif isinstance(x, dict):
+        return {k: convert_real_to_meta(v) for k, v in x.items()}
+    else:
+        return x
+
+
+def convert_meta_to_real(x, device):
+    if isinstance(x, torch.Tensor) and x.is_meta:
+        return torch.empty_like(x, device=device)
+    elif isinstance(x, (list, tuple)):
+        return type(x)(convert_meta_to_real(v, device) for v in x)
+    elif isinstance(x, dict):
+        return {k: convert_meta_to_real(v, device) for k, v in x.items()}
+    else:
+        return x
+
+
+def resolve_with_real_tensor(op_func, device, meta_args, meta_kwargs):
+    try:
+        real_args = convert_meta_to_real(meta_args, device)
+        real_kwargs = convert_meta_to_real(meta_kwargs, device)
+
+        real_out = op_func(*real_args, **real_kwargs)
+        return convert_real_to_meta(real_out)
+    except Exception:
+        return None
+
+
+torch._dynamo.config.capture_scalar_outputs = True
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+torch._dynamo.config.capture_sparse_compute = True
+torch._dynamo.config.raise_on_ctx_manager_usage = False
+torch._dynamo.config.allow_rnn = True
+
+
+class GraphMetaExecutor:
+    def __init__(self, device):
+        self.device = device
+        self.op_stats = {}
+        self.is_complete = True
+        self.num_ops = 0
+        self.num_ops_misses_dtypes = 0
+        self.subgraph_counter = 0
+
+    def get_output_dtype(self, out):
+        if isinstance(out, torch.Tensor):
+            return out.dtype
+        if (
+            isinstance(out, (list, tuple))
+            and len(out) > 0
+            and isinstance(out[0], torch.Tensor)
+        ):
+            return out[0].dtype
+        else:
+            return None
+
+    def get_op_name_and_func(self, gm, node, node_outputs):
+        op_name = None
+        op_func = None
+        try:
+            if node.op == "call_module":
+                # class name of the module
+                submod = gm.get_submodule(node.target)
+                op_name = submod.__class__.__name__
+                op_func = submod
+            elif node.op == "call_function":
+                op_name = node.target.__name__
+                op_func = node.target
+            elif node.op == "call_method":
+                op_name = node.target
+                self_obj = (
+                    node_outputs[node.args[0].name]
+                    if isinstance(node.args[0], torch.fx.Node)
+                    else node.args[0]
+                )
+                op_func = getattr(self_obj, node.target)
+            elif node.op in ["get_attr", "placeholder", "output"]:
+                op_name = node.op
+        except Exception:
+            pass
+        return op_name, op_func
+
+    def update_op_stats(self, op_stats, op_name, op_dtype):
+        if op_name is not None:
+            dtype_str = str(op_dtype).replace("torch.", "")
+            if op_stats.get(op_name, None) is None:
+                op_stats[op_name] = collect_stats_util.OpStat(
+                    op_name, {dtype_str: 1}, 1
+                )
+            else:
+                op_stats[op_name].op_dtypes[dtype_str] = (
+                    op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
+                )
+                op_stats[op_name].count += 1
+
+    def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
+        # Use meta tensors as input to avoid actually running the model
+        meta_sample_inputs = convert_real_to_meta(sample_inputs)
+
+        op_stats = {}
+        num_ops_misses_dtypes = 0
+
+        input_idx = 0
+        node_outputs = {}
+        for node in gm.graph.nodes:
+            out = None
+            op_dtype = None
+            op_name, op_func = self.get_op_name_and_func(gm, node, node_outputs)
+            if node.op == "placeholder":
+                out = meta_sample_inputs[input_idx]
+                input_idx += 1
+            elif node.op in ["call_function", "call_module", "call_method"]:
+                try:
+                    node_args = torch.fx.map_arg(
+                        node.args,
+                        lambda n: node_outputs[n.name]
+                        if isinstance(n, torch.fx.Node)
+                        else n,
+                    )
+                    node_kwargs = torch.fx.map_arg(
+                        node.kwargs,
+                        lambda n: node_outputs[n.name]
+                        if isinstance(n, torch.fx.Node)
+                        else n,
+                    )
+                    if node.op == "call_method":
+                        node_args = node_args[1:]
+
+                    if op_name == "_native_multi_head_attention":
+                        out = resolve_native_multi_head_attention(
+                            *node_args, **node_kwargs
+                        )
+                    elif op_name == "to":
+                        out = resolve_tensor_to(
+                            node_outputs[node.args[0].name], *node_args, **node_kwargs
+                        )
+                    elif op_name == "item":
+                        out = resolve_tensor_item(node_outputs[node.args[0].name])
+                    else:
+                        assert op_func is not None, f"op_func of {node} is None."
+                        out = op_func(*node_args, **node_kwargs)
+                except Exception:
+                    out = resolve_with_real_tensor(
+                        op_func, self.device, node_args, node_kwargs
+                    )
+                    if out is None:
+                        if num_ops_misses_dtypes == 0:
+                            print(
+                                f"dtype inference failed: node.op={node.op}, op_name={op_name}"
+                            )
+                        num_ops_misses_dtypes += 1
+            elif node.op == "get_attr":
+                out = resolve_get_attr(gm, node)
+            elif node.op == "output":
+                pass
+            else:
+                assert False, f"node.op: {node.op}"
+
+            if out is not None:
+                node_outputs[node.name] = out
+                op_dtype = self.get_output_dtype(out)
+
+            if node.op not in ["placeholder", "output"]:
+                self.update_op_stats(op_stats, op_name, op_dtype)
+
+        if num_ops_misses_dtypes > 0:
+            self.is_complete = False
+        self.num_ops_misses_dtypes += num_ops_misses_dtypes
+        num_ops = 0
+        for name, stat in op_stats.items():
+            num_ops += stat.count
+            if name in self.op_stats.keys():
+                self.op_stats[name].update(stat)
+            else:
+                self.op_stats[name] = stat
+        self.num_ops += num_ops
+        self.subgraph_counter += 1
+        return gm.forward
+
+    def summary(self):
+        print(
+            f"Collected {self.subgraph_counter} subgraphs and {self.num_ops} operators in total; dtype inference failed for {self.num_ops_misses_dtypes} operators."
+        )
+
+
+def collect_op_stats_with_compile(model, sample_inputs, device):
+    assert isinstance(model, torch.nn.Module), f"{type(model)=}"
+    try:
+        meta_executor = GraphMetaExecutor(device)
+        compiled_model = torch.compile(model, backend=meta_executor)
+        compiled_model(*sample_inputs)
+        meta_executor.summary()
+        return meta_executor
+    except Exception:
+        print("Failed with torch.compile")
+        return None
+
+
+def collect_op_stats_with_symbolic_trace(model, sample_inputs, device):
+    assert isinstance(model, torch.nn.Module), f"{type(model)=}"
+    try:
+        # FX symbolic trace
+        traced = torch.fx.symbolic_trace(model)
+    except Exception:
+        print("Failed with symbolic_trace")
+        return None
+
+    meta_executor = GraphMetaExecutor(device)
+    meta_executor(traced, sample_inputs)
+    meta_executor.summary()
+    return meta_executor
+
+
+def collect_op_stats(model, sample_inputs, device):
+    meta_executor_symbolic = collect_op_stats_with_symbolic_trace(
+        model, sample_inputs, device
+    )
+    if meta_executor_symbolic is None or not meta_executor_symbolic.is_complete:
+        meta_executor_compile = collect_op_stats_with_compile(
+            model, sample_inputs, device
+        )
+        if meta_executor_symbolic is None or (
+            meta_executor_compile is not None and meta_executor_compile.is_complete
+        ):
+            return "torch.compile", meta_executor_compile
+    return "symbolic_trace", meta_executor_symbolic
+
+
+def collect_model_stats(model_path, device, log_prompt):
+    file_path = os.path.join(model_path, "model.py")
+    model_class = collect_stats_util.load_class_from_file(file_path, "GraphModule")
+    model = model_class()
+    argument_name2types = collect_stats_util.get_argument_name_and_types(
+        model_class, "forward"
+    )
+
+    input_dict = get_input_dict(model_path, device)
+    ordered_input_list = [
+        input_dict[arg_name] for arg_name in argument_name2types.keys()
+    ]
+
+    ops_count_dict = {}
+    op_dtypes = {}
+    method, meta_executor = collect_op_stats(model, ordered_input_list, device)
+    if meta_executor is not None:
+        for op_name, stat in sorted(meta_executor.op_stats.items()):
+            if op_name not in ["placeholder", "output"]:
+                ops_count_dict[op_name] = stat.count
+                for dtype_str, num in stat.op_dtypes.items():
+                    if dtype_str is not None and dtype_str != "None":
+                        op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num
+
+    model_size = 0
+    input_shapes = set()
+    input_dtypes = {}
+    param_dtypes = {}
+    for name, arg_type in argument_name2types.items():
+        if (
+            name.startswith("L_self_modules_")
+            or arg_type == torch.nn.parameter.Parameter
+        ):
+            # Some parameters like L_self_modules_bn1_buffers_running_mean_ are torch.Tensor.
+            param_numel = math.prod(input_dict[name].shape)
+            model_size += param_numel
+            dtype_str = str(input_dict[name].dtype).replace("torch.", "")
+            param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
+        elif arg_type == torch.Tensor:
+            dtype_str = str(input_dict[name].dtype).replace("torch.", "")
+            input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
+            input_shapes.add(str(list(input_dict[name].shape)))
+
+    num_outputs = collect_stats_util.get_number_of_returns(
+        file_path, "GraphModule", "forward"
+    )
+    num_ops = meta_executor.num_ops if meta_executor is not None else 0
+    source, heuristic_tag = read_graph_source_and_tag(model_path)
+
+    is_complete = meta_executor.is_complete if meta_executor is not None else False
+    print(
+        f"model_stats collection information: model_path={model_path} method={method} is_ops_complete={is_complete}"
+    )
+
+    stats = collect_stats_util.ModelStats(
+        model_path=model_path,
+        num_inputs=sum(input_dtypes.values()),
+        num_params=sum(param_dtypes.values()),
+        num_outputs=num_outputs,
+        num_ops=num_ops,
+        model_size_in_billion=model_size / 1e9,
+        input_dtypes=input_dtypes,
+        param_dtypes=param_dtypes,
+        input_shapes=list(input_shapes),
+        op_dtypes=op_dtypes,
+        ops=ops_count_dict,
+        source=source,
+        heuristic_tag=heuristic_tag,
+    )
+    collect_stats_util.print_model_stats(stats, log_prompt)
+
+
+def main(args):
+    if args.model_path is not None:
+        assert os.path.isdir(args.model_path)
+        assert is_single_model_dir(args.model_path)
+        timestamp_sec = datetime.now().timestamp()
+        dt = datetime.fromtimestamp(timestamp_sec)
+        formatted_dt = dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+        print(f"[{formatted_dt}] Collect information for {args.model_path}")
+        collect_model_stats(args.model_path, args.device, args.log_prompt)
+    else:
+        graph_net_samples_path = (
+            (graph_net.torch.samples_util.get_default_samples_directory())
+            if args.graph_net_samples_path is None
+            else args.graph_net_samples_path
+        )
+
+        i = 0
+        for root, dirs, files in os.walk(graph_net_samples_path):
+            if is_single_model_dir(root):
+                print(f"[{i}] Collect information for {root}")
+                cmd = [
+                    "python",
+                    "-m",
+                    "graph_net.torch.collect_stats",
+                    f"--device={args.device}",
+                    f"--model-path={root}",
+                    f"--log-prompt={args.log_prompt}",
+                ]
+                result = subprocess.run(
+                    cmd,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                    text=True,
+                    timeout=600,
+                )
+                print(result.stdout)
+                if result.returncode != 0:
+                    print(result.stderr)
+                i += 1
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Collect stats for computation graph samples. Returns 0 on success."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=False,
+        default="cuda",
+        help="Device to use (e.g., 'cpu' or 'cuda')",
+    )
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        required=False,
+        default=None,
+        help="Computation graph sample directory, e.g. '../../samples/torch/resnet18'",
+    )
+    parser.add_argument(
+        "--graph-net-samples-path",
+        type=str,
+        required=False,
+        default=None,
+        help="GraphNet samples directory, e.g. '../../samples'",
+    )
+    parser.add_argument(
+        "--log-prompt",
+        type=str,
+        required=False,
+        default="graph-net-collect-stats-log",
+        help="Log prompt for stats log filtering.",
+    )
+    args = parser.parse_args()
+    main(args=args)
diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py
index 5922991c5..1348c4ddc 100644
--- a/graph_net/torch/test_compiler.py
+++ b/graph_net/torch/test_compiler.py
@@ -1,4 +1,3 @@
-from . import utils
 import argparse
 import importlib.util
 import inspect
@@ -14,6 +13,8 @@ import json
 import numpy as np
 import platform
+
+from graph_net.torch import utils
 from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend
 from graph_net.torch.backend.tvm_backend import TvmBackend
 from graph_net.torch.backend.xla_backend import XlaBackend
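
A minimal usage sketch for the new entry points, assuming a sample directory that contains graph_net.json; the paths below are illustrative:

    # Per-sample collection, mirroring what main() does for a single --model-path:
    #   python -m graph_net.torch.collect_stats --model-path=../../samples/torch/resnet18
    # or, equivalently, from Python:
    from graph_net.torch import collect_stats

    collect_stats.collect_model_stats(
        model_path="../../samples/torch/resnet18",  # illustrative sample directory
        device="cuda",
        log_prompt="graph-net-collect-stats-log",
    )

Omitting --model-path (or passing --graph-net-samples-path) walks the samples tree and runs the same collection in one subprocess per sample, filtering results by the log prompt.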