@@ -99,7 +99,9 @@ def inference_cost(
9999 :param preprocess: If set, run preprocessing steps such as shape inference,
100100 datatype inference and constant folding. Strongly recommended.
101101 :param discount_sparsity: If set, will discount op cost of MAC ops with a
102- constant zero weight, and the mem cost of constant zero weights."""
102+ constant zero weight, and the mem cost of constant zero weights.
103+ :param cost_breakdown: If set, include per-node (by name) and per-node-type
104+ breakdowns as part of the returned inference cost dict."""
103105
104106 combined_results = {}
105107 if isinstance (model_filename_or_wrapper , ModelWrapper ):
@@ -130,26 +132,19 @@ def inference_cost(
130132 res ["total_macs" ] = macs
131133 if "unsupported" in res :
132134 res ["unsupported" ] = str (res ["unsupported" ])
133- if output_json is not None :
134- with open (output_json , "w" ) as f :
135- json .dump (res , f , sort_keys = True , indent = 2 )
136135 combined_results [i ] = res
137- elif i == "optype_cost" :
138- per_optype_breakdown = {}
136+ else :
137+ per_optype_or_node_breakdown = {}
139138 for optype , op_res in res .items ():
140139 bops , macs = compute_bops_and_macs (op_res )
141140 op_res = assign_mem_bits_and_elems (op_res )
142141 op_res ["total_bops" ] = bops
143142 op_res ["total_macs" ] = macs
144- per_optype_breakdown [optype ] = op_res
145- combined_results [i ] = per_optype_breakdown
146- else :
147- per_node_breakdown = {}
148- for node_name in res .keys ():
149- node_res = res [node_name ]
150- node_res = assign_mem_bits_and_elems (node_res )
151- per_node_breakdown [node_name ] = node_res
152- combined_results [i ] = per_node_breakdown
143+ per_optype_or_node_breakdown [optype ] = op_res
144+ combined_results [i ] = per_optype_or_node_breakdown
145+ if output_json is not None :
146+ with open (output_json , "w" ) as f :
147+ json .dump (combined_results , f , sort_keys = True , indent = 2 )
153148 return combined_results
154149
155150
0 commit comments