@@ -110,20 +110,18 @@ def _test_dtype_passes(
110110 traced_model: Traced FX GraphModule
111111
112112 Returns:
113- List of pass names that work
113+ List of pass names that work (pass file names without .py extension)
114114 """
115115 working_passes = []
116116
117117 for dtype in self .dtype_list :
118- pass_name = "dtype_generalization_pass" # Full pass file name without .py
118+ # Pass name directly corresponds to file name (without .py)
119+ pass_name = f"dtype_generalization_pass_{ dtype } "
119120
120121 try :
121- # Try to apply the pass
122+ # Try to load and apply the pass
122123 dtype_pass_class = get_dtype_generalization_pass (pass_name )
123- dtype_pass = dtype_pass_class (
124- target_dtype = dtype ,
125- preserve_weights = FLOAT32_PRESERVED_WEIGHTS ,
126- )
124+ dtype_pass = dtype_pass_class ()
127125
128126 # Check if pass is needed
129127 if not dtype_pass .need_rewrite (traced_model ):
@@ -135,11 +133,11 @@ def _test_dtype_passes(
135133
136134 # Try to run the modified graph
137135 if self ._test_graph_runnable (model_path , gm_copy , dtype ):
138- working_passes .append (f" { pass_name } _ { dtype } " )
139- logging .info (f"Pass { pass_name } _ { dtype } works for { model_path } " )
136+ working_passes .append (pass_name )
137+ logging .info (f"Pass { pass_name } works for { model_path } " )
140138
141139 except (RuntimeError , ValueError , TypeError ) as e :
142- logging .warning (f"Pass { pass_name } _ { dtype } failed: { e } " )
140+ logging .warning (f"Pass { pass_name } failed: { e } " )
143141 continue
144142
145143 return working_passes
@@ -302,37 +300,25 @@ def _apply_pass_and_generate(
302300 Args:
303301 model_path: Original model path
304302 traced_model: Original traced model
305- pass_name: Name of the pass to apply (e.g., "dtype_generalization_pass_float16")
303+ pass_name: Name of the pass file (without .py extension),
304+ e.g., "dtype_generalization_pass_float16"
306305
307306 Returns:
308307 Path to the generated sample directory
309308 """
310- # Parse pass name to extract base name and dtype
311- # Format: "dtype_generalization_pass_float16" or "dtype_generalization_pass_bfloat16"
312- # The base name "dtype_generalization_pass" corresponds to the file
313- # dtype_generalization_pass.py, which contains the ConcretePass class.
314- parts = pass_name .rsplit ("_" , 1 )
315- if len (parts ) != 2 :
309+ # Pass name directly corresponds to file name (without .py)
310+ # Extract dtype from pass name for output directory naming
311+ if not pass_name .startswith ("dtype_generalization_pass_" ):
316312 raise ValueError (
317- f"Invalid pass name format : { pass_name } . "
313+ f"Invalid pass name: { pass_name } . "
318314 f"Expected format: 'dtype_generalization_pass_<dtype>'"
319315 )
320316
321- base_name , dtype = parts
322-
323- # Validate base name
324- if base_name != "dtype_generalization_pass" :
325- raise ValueError (
326- f"Unknown pass base name: { base_name } . "
327- f"Expected: 'dtype_generalization_pass'"
328- )
317+ dtype = pass_name .replace ("dtype_generalization_pass_" , "" )
329318
330319 # Load and apply the pass
331- dtype_pass_class = get_dtype_generalization_pass (base_name )
332- dtype_pass = dtype_pass_class (
333- target_dtype = dtype ,
334- preserve_weights = FLOAT32_PRESERVED_WEIGHTS ,
335- )
320+ dtype_pass_class = get_dtype_generalization_pass (pass_name )
321+ dtype_pass = dtype_pass_class ()
336322
337323 gm_copy = copy .deepcopy (traced_model )
338324 gm_modified = dtype_pass .rewrite (gm_copy )
0 commit comments