11import argparse
22import os
3+ import re
34import sys
45import ast
56import math
@@ -118,6 +119,24 @@ def update_op_stats(self, op_name, op_dtype):
118119 )
119120 self .op_stats [op_name ].count += 1
120121
122+ def parse_pir_value_dtypes (self , type_str ):
123+ short_form2dtype = {
124+ "f32" : "float32" ,
125+ "f16" : "float16" ,
126+ "bf16" : "bfloat16" ,
127+ "i64" : "int64" ,
128+ }
129+ # type_str: "vec[tensor<1x18x13x9xf32>,tensor<1x9x13x9xf32>]"
130+ matches = re .findall (r"tensor<([^>]+)>" , type_str )
131+ dtype_strs = []
132+ for s in matches :
133+ parts = s .split ("x" )
134+ assert len (parts ) > 0
135+
136+ dtype = parts [- 1 ].lower ()
137+ dtype_strs .append (short_form2dtype [dtype ])
138+ return dtype_strs
139+
121140 def __call__ (self , program ):
122141 assert isinstance (program , paddle .base .libpaddle .pir .Program )
123142
@@ -129,22 +148,38 @@ def __call__(self, program):
129148 op_name = None
130149 op_dtype = None
131150 if op .name () == "pd_op.data" :
151+ op_name = "data"
132152 op_attrs = op .attrs ()
133153 op_dtype = op_attrs ["dtype" ]
134154 self .input_dict [op_attrs ["name" ]] = {
135155 "dtype" : str (op_dtype ).replace ("paddle." , "" ),
136156 "shape" : op_attrs ["shape" ],
137157 }
138- elif not op .name ().startswith ("builtin ." ):
158+ elif op .name ().startswith ("pd_op ." ):
139159 self .num_ops += 1
140160 op_name = op .name ().replace ("pd_op." , "" )
141- if len (op .results ()) > 0 :
142- op_dtype = op .results ()[0 ].dtype
143-
144- if op_name is not None :
145- self .update_op_stats (op_name , op_dtype )
146- elif op_dtype is None :
147- self .num_ops_misses_dtypes += 1
161+ try :
162+ if len (op .results ()) > 0 :
163+ out = op .results ()[0 ]
164+ if out .is_dense_tensor_type ():
165+ op_dtype = out .dtype
166+ else :
167+ # for paddle.base.libpaddle.pir.VectorType, but cannot be accurately determined
168+ if op_name in ["split" , "split_with_num" , "meshgrid" ]:
169+ op_dtype = self .parse_pir_value_dtypes (
170+ str (out .type ())
171+ )[0 ]
172+ else :
173+ assert False , f"Unsupport op: { op } "
174+ except Exception :
175+ if self .num_ops_misses_dtypes == 0 :
176+ print (f"dtype inference failed for { op_name } " )
177+ if op_dtype is not None :
178+ self .update_op_stats (op_name , op_dtype )
179+ else :
180+ self .num_ops_misses_dtypes += 1
181+ elif not op .name ().startswith ("builtin." ):
182+ assert False , f"Unrecognized op: { op } "
148183
149184 if self .num_ops_misses_dtypes > 0 :
150185 self .is_complete = False
@@ -281,7 +316,7 @@ def main(args):
281316 cmd = [
282317 "python" ,
283318 "-m" ,
284- "graph_net.torch .collect_stats" ,
319+ "graph_net.paddle .collect_stats" ,
285320 f"--device={ args .device } " ,
286321 f"--model-path={ root } " ,
287322 f"--log-prompt={ args .log_prompt } " ,
0 commit comments