11import json
22from pathlib import Path
3- from typing import Union
43
54kDimensionGeneralizationPasses = "dimension_generalization_passes"
65kDataTypeGeneralizationPasses = "data_type_generalization_passes"
76kSymbolicDimensionReifier = "symbolic_dimension_reifier"
87
8+ # Fields for dtype generalization metadata
9+ kDtypeGeneralizationTargetDtype = "dtype_generalization_target_dtype"
10+ kDtypeGeneralizationPrecision = "dtype_generalization_precision"
11+ kDtypeGeneralizationGenerated = "dtype_generalization_generated"
12+
913
1014def read_json (model_path ):
1115 """
@@ -21,44 +25,37 @@ def read_json(model_path):
2125 return json .loads (graph_net_json_file_path .read_text ())
2226
2327
24- def update_json (json_path : Union [ str , Path ], updates : dict ) -> None :
28+ def update_json (model_path , field , value ) :
2529 """
26- Atomically update a JSON file with the given updates .
30+ Update a single field in graph_net.json .
2731
2832 Args:
29- json_path: Path to the JSON file
30- updates: Dictionary of key-value pairs to update
33+ model_path: Path to model directory or graph_net.json file
34+ field: Field name to update
35+ value: Value to set
3136 """
32- json_path = Path (json_path )
37+ if isinstance (model_path , (str , Path )):
38+ model_path = Path (model_path )
39+ # If it's a file path, use it directly; otherwise assume it's a directory
40+ if model_path .suffix == ".json" :
41+ graph_net_json_file_path = model_path
42+ else :
43+ graph_net_json_file_path = model_path / "graph_net.json"
44+ else :
45+ graph_net_json_file_path = Path (f"{ model_path } /graph_net.json" )
3346
3447 # Read existing JSON
35- if json_path .exists ():
36- with open (json_path , "r" ) as f :
37- metadata = json .load (f )
48+ if graph_net_json_file_path .exists ():
49+ with open (graph_net_json_file_path , "r" ) as f :
50+ graph_net_json = json .load (f )
3851 else :
39- metadata = {}
52+ graph_net_json = {}
4053
41- # Apply updates
42- metadata . update ( updates )
54+ # Update field
55+ graph_net_json [ field ] = value
4356
4457 # Atomic write: write to temp file then rename
45- temp_path = json_path .with_suffix (".json.tmp" )
58+ temp_path = graph_net_json_file_path .with_suffix (".json.tmp" )
4659 with open (temp_path , "w" ) as f :
47- json .dump (metadata , f , indent = 4 )
48- temp_path .replace (json_path )
49-
50-
51- # Backward compatibility: old interface using model_path, field, value
52- def update_json_legacy (model_path , field , value ):
53- """
54- Legacy interface for updating a single field in graph_net.json.
55-
56- Args:
57- model_path: Path to model directory
58- field: Field name to update
59- value: Value to set
60- """
61- graph_net_json_file_path = Path (f"{ model_path } /graph_net.json" )
62- graph_net_json = json .loads (graph_net_json_file_path .read_text ())
63- graph_net_json [field ] = value
64- graph_net_json_file_path .write_text (json .dumps (graph_net_json , indent = 4 ))
60+ json .dump (graph_net_json , f , indent = 4 )
61+ temp_path .replace (graph_net_json_file_path )
0 commit comments