1919
2020import torch .fx as fx
2121
22- from graph_net .graph_net_json_file_util import kDataTypeGeneralizationPasses
22+ from graph_net .graph_net_json_file_util import (
23+ kDataTypeGeneralizationPasses ,
24+ update_json ,
25+ )
2326from graph_net .torch .constraint_util import RunModelPredicator
2427from graph_net .torch .fx_graph_cache_util import (
2528 parse_immutable_model_path_into_sole_graph_module ,
2629)
2730from graph_net .torch .fx_graph_serialize_util import serialize_graph_module_to_str
2831from graph_net .torch .dtype_gen_passes .pass_mgr import get_dtype_generalization_pass
2932from graph_net .torch import utils
33+ from graph_net .imp_util import load_module
3034
3135
3236# Weights that must remain float32 for numerical stability
@@ -51,12 +55,14 @@ class InitDataTypeGeneralizationPasses:
5155 Config format:
5256 {
5357 "dtype_list": ["float16", "bfloat16"],
58+ "model_path_prefix": "",
5459 }
5560 """
5661
5762 def __init__ (self , config : Dict [str , Any ]):
5863 self .config = config
5964 self .dtype_list = config .get ("dtype_list" , ["float16" , "bfloat16" ])
65+ self .model_path_prefix = config .get ("model_path_prefix" , "" )
6066
6167 # Validate dtypes
6268 valid_dtypes = {"float16" , "bfloat16" , "float8" }
@@ -71,8 +77,12 @@ def __call__(self, model_path: str) -> None:
7177 Initialize dtype passes for the given model.
7278
7379 Args:
74- model_path: Path to the model directory
80+ model_path: Path to the model directory (may be relative to model_path_prefix)
7581 """
82+ # Apply model_path_prefix if provided
83+ if self .model_path_prefix :
84+ model_path = str (Path (self .model_path_prefix ) / model_path )
85+
7686 # Parse the computation graph
7787 traced_model = parse_immutable_model_path_into_sole_graph_module (model_path )
7888
@@ -176,18 +186,7 @@ def _save_dtype_pass_names(
176186 model_path: Path to model directory
177187 """
178188 graph_net_json_path = Path (model_path ) / "graph_net.json"
179-
180- with open (graph_net_json_path , "r" ) as f :
181- metadata = json .load (f )
182-
183- metadata [kDataTypeGeneralizationPasses ] = dtype_pass_names
184-
185- # Atomic write: write to temp file then rename
186- temp_path = graph_net_json_path .with_suffix (".json.tmp" )
187- with open (temp_path , "w" ) as f :
188- json .dump (metadata , f , indent = 4 )
189-
190- temp_path .replace (graph_net_json_path )
189+ update_json (graph_net_json_path , {kDataTypeGeneralizationPasses : dtype_pass_names })
191190
192191
193192class ApplyDataTypeGeneralizationPasses :
@@ -200,6 +199,10 @@ class ApplyDataTypeGeneralizationPasses:
200199 Config format:
201200 {
202201 "output_dir": "/path/to/output",
202+ "model_path_prefix": "",
203+ "model_runnable_predicator_filepath": "...",
204+ "model_runnable_predicator_class_name": "...",
205+ "model_runnable_predicator_config": {...},
203206 }
204207 """
205208
@@ -208,17 +211,41 @@ def __init__(self, config: Dict[str, Any]):
208211 self .output_dir = config .get ("output_dir" )
209212 if not self .output_dir :
210213 raise ValueError ("output_dir is required in config" )
214+
215+ self .model_path_prefix = config .get ("model_path_prefix" , "" )
216+
217+ # model_runnable_predicator is required to ensure generated code is runnable
218+ if "model_runnable_predicator_filepath" not in config :
219+ raise ValueError (
220+ "model_runnable_predicator_filepath is required in config. "
221+ "Generated code must be validated."
222+ )
223+ self .model_runnable_predicator = self ._make_model_runnable_predicator (config )
224+
def _make_model_runnable_predicator(self, config: Dict[str, Any]):
    """
    Build the predicator used to check that generated samples are runnable.

    Args:
        config: Mapping that must contain "model_runnable_predicator_filepath".
            "model_runnable_predicator_class_name" is optional and falls back
            to "RunModelPredicator"; "model_runnable_predicator_config" is
            optional and falls back to an empty dict.

    Returns:
        An instance of the selected class, constructed with the predicator config.
    """
    # Load the module first so a missing filepath key fails before anything else.
    predicator_module = load_module(config["model_runnable_predicator_filepath"])
    class_name = config.get(
        "model_runnable_predicator_class_name", "RunModelPredicator"
    )
    predicator_cls = getattr(predicator_module, class_name)
    return predicator_cls(config.get("model_runnable_predicator_config", {}))
211234
212235 def __call__ (self , model_path : str ) -> List [str ]:
213236 """
214237 Apply dtype passes to generate new samples.
215238
216239 Args:
217- model_path: Path to the original model directory
240+ model_path: Path to the original model directory (may be relative to model_path_prefix)
218241
219242 Returns:
220243 List of generated sample directories
221244 """
245+ # Apply model_path_prefix if provided
246+ if self .model_path_prefix :
247+ model_path = str (Path (self .model_path_prefix ) / model_path )
248+
222249 # Read pass names from graph_net.json
223250 dtype_pass_names = self ._read_dtype_pass_names (model_path )
224251
@@ -316,6 +343,13 @@ def _apply_pass_and_generate(
316343 # Update graph_net.json with dtype information
317344 self ._update_sample_metadata (output_sample_dir , dtype )
318345
346+ # Validate generated sample (required - generated code must be runnable)
347+ if not self .model_runnable_predicator (str (output_sample_dir )):
348+ raise RuntimeError (
349+ f"Generated sample failed validation: { output_sample_dir } "
350+ )
351+ logging .info (f"Generated sample validated: { output_sample_dir } " )
352+
319353 return str (output_sample_dir )
320354
321355 def _update_sample_metadata (self , sample_dir : Path , dtype : str ) -> None :
@@ -327,20 +361,14 @@ def _update_sample_metadata(self, sample_dir: Path, dtype: str) -> None:
327361 dtype: Target dtype
328362 """
329363 graph_net_json_path = sample_dir / "graph_net.json"
330-
331- with open (graph_net_json_path , "r" ) as f :
332- metadata = json .load (f )
333-
334- # Add dtype information
335- metadata ["dtype" ] = dtype
336- metadata ["precision" ] = dtype
337- metadata ["generated_from_dtype_generalization" ] = True
338-
339- # Atomic write
340- temp_path = graph_net_json_path .with_suffix (".json.tmp" )
341- with open (temp_path , "w" ) as f :
342- json .dump (metadata , f , indent = 4 )
343- temp_path .replace (graph_net_json_path )
364+ update_json (
365+ graph_net_json_path ,
366+ {
367+ "dtype" : dtype ,
368+ "precision" : dtype ,
369+ "generated_from_dtype_generalization" : True ,
370+ },
371+ )
344372
345373
346374class MultiDtypeFilter :
0 commit comments