|
7 | 7 | from __future__ import print_function
|
8 | 8 |
|
9 | 9 | import argparse
|
| 10 | +from collections import Counter |
10 | 11 | import logging
|
11 | 12 | import os
|
12 | 13 | import sys
|
@@ -134,6 +135,19 @@ def get_graph_io_nodes(input_path):
|
134 | 135 | logging.info("\t%s inputs: %s", len(inputs), ','.join(inputs))
|
135 | 136 | logging.info("\t%s (possible) outputs: %s", len(inputs), ','.join(outputs))
|
136 | 137 |
|
| 138 | + @staticmethod |
| 139 | + def print_graph_stat(input_path): |
| 140 | + logging.info("load from %s", input_path) |
| 141 | + graph_def = load_graph_def_from_pb(input_path) |
| 142 | + |
| 143 | + op_stat = Counter() |
| 144 | + for node in graph_def.node: |
| 145 | + op_stat[node.op] += 1 |
| 146 | + |
| 147 | + logging.info("graph stat:") |
| 148 | + for op, count in sorted(op_stat.items(), key=lambda x: x[0]): |
| 149 | + logging.info("\t%s = %s", op, count) |
| 150 | + |
137 | 151 | @staticmethod
|
138 | 152 | def extract_sub_graph(input_path, output_path=None, dest_nodes=None):
|
139 | 153 | if not output_path:
|
@@ -178,6 +192,11 @@ def extract_sub_graph(input_path, output_path=None, dest_nodes=None):
|
178 | 192 | subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
|
179 | 193 | subparser.set_defaults(func=main.get_graph_io_nodes)
|
180 | 194 |
|
| 195 | + # stat |
| 196 | + subparser = subparsers.add_parser("stat", help="print stat") |
| 197 | + subparser.add_argument("--input", dest="input_path", required=True, help="input pb path") |
| 198 | + subparser.set_defaults(func=main.print_graph_stat) |
| 199 | + |
181 | 200 | # extract
|
182 | 201 | subparser = subparsers.add_parser("extract", help="extract sub-graph")
|
183 | 202 | subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
|
|
0 commit comments