Skip to content

Commit 90afcc9

Browse files
committed
Refactor: fix update_json API and add dtype generalization field constants
1 parent 34b300e commit 90afcc9

File tree

2 files changed

+49
-45
lines changed

2 files changed

+49
-45
lines changed
Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import json
22
from pathlib import Path
3-
from typing import Union
43

54
kDimensionGeneralizationPasses = "dimension_generalization_passes"
65
kDataTypeGeneralizationPasses = "data_type_generalization_passes"
76
kSymbolicDimensionReifier = "symbolic_dimension_reifier"
87

8+
# Fields for dtype generalization metadata
9+
kDtypeGeneralizationTargetDtype = "dtype_generalization_target_dtype"
10+
kDtypeGeneralizationPrecision = "dtype_generalization_precision"
11+
kDtypeGeneralizationGenerated = "dtype_generalization_generated"
12+
913

1014
def read_json(model_path):
1115
"""
@@ -21,44 +25,37 @@ def read_json(model_path):
2125
return json.loads(graph_net_json_file_path.read_text())
2226

2327

24-
def update_json(json_path: Union[str, Path], updates: dict) -> None:
28+
def update_json(model_path, field, value):
2529
"""
26-
Atomically update a JSON file with the given updates.
30+
Update a single field in graph_net.json.
2731
2832
Args:
29-
json_path: Path to the JSON file
30-
updates: Dictionary of key-value pairs to update
33+
model_path: Path to model directory or graph_net.json file
34+
field: Field name to update
35+
value: Value to set
3136
"""
32-
json_path = Path(json_path)
37+
if isinstance(model_path, (str, Path)):
38+
model_path = Path(model_path)
39+
# If it's a file path, use it directly; otherwise assume it's a directory
40+
if model_path.suffix == ".json":
41+
graph_net_json_file_path = model_path
42+
else:
43+
graph_net_json_file_path = model_path / "graph_net.json"
44+
else:
45+
graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
3346

3447
# Read existing JSON
35-
if json_path.exists():
36-
with open(json_path, "r") as f:
37-
metadata = json.load(f)
48+
if graph_net_json_file_path.exists():
49+
with open(graph_net_json_file_path, "r") as f:
50+
graph_net_json = json.load(f)
3851
else:
39-
metadata = {}
52+
graph_net_json = {}
4053

41-
# Apply updates
42-
metadata.update(updates)
54+
# Update field
55+
graph_net_json[field] = value
4356

4457
# Atomic write: write to temp file then rename
45-
temp_path = json_path.with_suffix(".json.tmp")
58+
temp_path = graph_net_json_file_path.with_suffix(".json.tmp")
4659
with open(temp_path, "w") as f:
47-
json.dump(metadata, f, indent=4)
48-
temp_path.replace(json_path)
49-
50-
51-
# Backward compatibility: old interface using model_path, field, value
52-
def update_json_legacy(model_path, field, value):
53-
"""
54-
Legacy interface for updating a single field in graph_net.json.
55-
56-
Args:
57-
model_path: Path to model directory
58-
field: Field name to update
59-
value: Value to set
60-
"""
61-
graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
62-
graph_net_json = json.loads(graph_net_json_file_path.read_text())
63-
graph_net_json[field] = value
64-
graph_net_json_file_path.write_text(json.dumps(graph_net_json, indent=4))
60+
json.dump(graph_net_json, f, indent=4)
61+
temp_path.replace(graph_net_json_file_path)

graph_net/torch/dtype_generalizer.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
from graph_net.graph_net_json_file_util import (
2323
kDataTypeGeneralizationPasses,
24+
kDtypeGeneralizationTargetDtype,
25+
kDtypeGeneralizationPrecision,
26+
kDtypeGeneralizationGenerated,
2427
update_json,
2528
)
2629
from graph_net.torch.constraint_util import RunModelPredicator
@@ -185,10 +188,7 @@ def _save_dtype_pass_names(
185188
dtype_pass_names: List of working pass names
186189
model_path: Path to model directory
187190
"""
188-
graph_net_json_path = Path(model_path) / "graph_net.json"
189-
update_json(
190-
graph_net_json_path, {kDataTypeGeneralizationPasses: dtype_pass_names}
191-
)
191+
update_json(model_path, kDataTypeGeneralizationPasses, dtype_pass_names)
192192

193193

194194
class ApplyDataTypeGeneralizationPasses:
@@ -308,13 +308,25 @@ def _apply_pass_and_generate(
308308
Path to the generated sample directory
309309
"""
310310
# Parse pass name to extract base name and dtype
311-
# Format: "dtype_generalization_pass_float16"
311+
# Format: "dtype_generalization_pass_float16" or "dtype_generalization_pass_bfloat16"
312+
# The base name "dtype_generalization_pass" corresponds to the file
313+
# dtype_generalization_pass.py, which contains the ConcretePass class.
312314
parts = pass_name.rsplit("_", 1)
313315
if len(parts) != 2:
314-
raise ValueError(f"Invalid pass name format: {pass_name}")
316+
raise ValueError(
317+
f"Invalid pass name format: {pass_name}. "
318+
f"Expected format: 'dtype_generalization_pass_<dtype>'"
319+
)
315320

316321
base_name, dtype = parts
317322

323+
# Validate base name
324+
if base_name != "dtype_generalization_pass":
325+
raise ValueError(
326+
f"Unknown pass base name: {base_name}. "
327+
f"Expected: 'dtype_generalization_pass'"
328+
)
329+
318330
# Load and apply the pass
319331
dtype_pass_class = get_dtype_generalization_pass(base_name)
320332
dtype_pass = dtype_pass_class(
@@ -363,14 +375,9 @@ def _update_sample_metadata(self, sample_dir: Path, dtype: str) -> None:
363375
dtype: Target dtype
364376
"""
365377
graph_net_json_path = sample_dir / "graph_net.json"
366-
update_json(
367-
graph_net_json_path,
368-
{
369-
"dtype": dtype,
370-
"precision": dtype,
371-
"generated_from_dtype_generalization": True,
372-
},
373-
)
378+
update_json(graph_net_json_path, kDtypeGeneralizationTargetDtype, dtype)
379+
update_json(graph_net_json_path, kDtypeGeneralizationPrecision, dtype)
380+
update_json(graph_net_json_path, kDtypeGeneralizationGenerated, True)
374381

375382

376383
class MultiDtypeFilter:

0 commit comments

Comments
 (0)