Skip to content

Commit 54356f4

Browse files
committed
Resolve merge conflicts: merge dtype and develop branches
1 parent 7da86c6 commit 54356f4

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed
Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
1-
kDimensionGeneralizationPasses = "dimension_generalization_passes"
2-
kDataTypeGeneralizationPasses = "data_type_generalization_passes"
3-
41
import json
52
from pathlib import Path
63
from typing import Union
74

5+
kDimensionGeneralizationPasses = "dimension_generalization_passes"
6+
kDataTypeGeneralizationPasses = "data_type_generalization_passes"
7+
kSymbolicDimensionReifier = "symbolic_dimension_reifier"
8+
9+
10+
def read_json(model_path):
11+
"""
12+
Read JSON from graph_net.json file.
13+
14+
Args:
15+
model_path: Path to model directory
16+
17+
Returns:
18+
Dictionary containing JSON data
19+
"""
20+
graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
21+
return json.loads(graph_net_json_file_path.read_text())
22+
823

924
def update_json(json_path: Union[str, Path], updates: dict) -> None:
1025
"""
@@ -15,19 +30,35 @@ def update_json(json_path: Union[str, Path], updates: dict) -> None:
1530
updates: Dictionary of key-value pairs to update
1631
"""
1732
json_path = Path(json_path)
18-
33+
1934
# Read existing JSON
2035
if json_path.exists():
2136
with open(json_path, "r") as f:
2237
metadata = json.load(f)
2338
else:
2439
metadata = {}
25-
40+
2641
# Apply updates
2742
metadata.update(updates)
28-
43+
2944
# Atomic write: write to temp file then rename
3045
temp_path = json_path.with_suffix(".json.tmp")
3146
with open(temp_path, "w") as f:
3247
json.dump(metadata, f, indent=4)
3348
temp_path.replace(json_path)
49+
50+
51+
# Backward compatibility: old interface using model_path, field, value
52+
def update_json_legacy(model_path, field, value):
53+
"""
54+
Legacy interface for updating a single field in graph_net.json.
55+
56+
Args:
57+
model_path: Path to model directory
58+
field: Field name to update
59+
value: Value to set
60+
"""
61+
graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
62+
graph_net_json = json.loads(graph_net_json_file_path.read_text())
63+
graph_net_json[field] = value
64+
graph_net_json_file_path.write_text(json.dumps(graph_net_json, indent=4))

samples/timm/resnet18/graph_net.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
"data_type_generalization_passes": [
88
"dtype_generalization_pass_float16",
99
"dtype_generalization_pass_bfloat16"
10-
]
11-
}
10+
],
11+
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
12+
}

0 commit comments

Comments
 (0)