Skip to content

Commit 2ace973

Browse files
committed
Merge branch 'develop' into opt_saved_results
2 parents fdc1994 + 417240e commit 2ace973

File tree

65 files changed

+2978
-722
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+2978
-722
lines changed

graph_net/analysis_util.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import sys
44
from scipy.stats import gmean
55
from graph_net.config.datatype_tolerance_config import get_precision
6+
from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
7+
from graph_net.verify_aggregated_params import determine_tolerances
68

79

810
def detect_sample_status(log_text: str) -> str:
@@ -293,38 +295,24 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b
293295
return False
294296

295297

296-
def fake_perf_degrad(tolerance, error_code, type="default") -> str:
298+
def fake_perf_degrad(
299+
tolerance,
300+
error_code,
301+
positive_tolerance_interpretation: PositiveToleranceInterpretation,
302+
) -> str:
297303
"""
298304
Judge current correctness based on tolerance t and status.
305+
Refactored to delegate logic to PositiveToleranceInterpretation classes.
299306
"""
300-
if type == "default":
301-
if tolerance >= 3:
302-
return "correct"
303-
elif error_code == "accuracy" and tolerance >= 1:
304-
return "correct"
305-
else:
306-
return error_code
307-
elif type == "extended":
308-
if (
309-
error_code == "compile_fail" or error_code == "runtime_fail"
310-
) and tolerance >= 4:
311-
return "correct"
312-
elif error_code == "eager_fail" and tolerance >= 3:
313-
return "correct"
314-
elif (
315-
error_code == "shape_mismatch" or error_code == "type_mismatch"
316-
) and tolerance >= 2:
317-
return "correct"
318-
elif error_code == "accuracy" and tolerance >= 1:
319-
return "correct"
320-
else:
321-
return error_code
322-
else:
323-
raise NotImplementedError
307+
if positive_tolerance_interpretation.is_error_tolerated(tolerance, error_code):
308+
return "correct"
309+
310+
return error_code
324311

325312

326313
def calculate_scores(
327314
samples: list,
315+
positive_tolerance_interpretation: PositiveToleranceInterpretation,
328316
p: float = 0,
329317
b: float = 0.1,
330318
type: str = "ESt",
@@ -339,7 +327,10 @@ def calculate_scores(
339327

340328
scores = {}
341329

342-
for tolerance in range(-10, 5):
330+
strategy = positive_tolerance_interpretation
331+
tolerances = determine_tolerances(samples, positive_tolerance_interpretation)
332+
333+
for tolerance in tolerances:
343334
rectified_speedups = []
344335
rectified_speedups_fake_degrad = []
345336

@@ -373,12 +364,10 @@ def calculate_scores(
373364
)
374365
else:
375366
if not is_correct_at_t1[idx]:
376-
current_correctness = fake_perf_degrad(
367+
is_tolerated = strategy.is_error_tolerated(
377368
tolerance, fail_type_at_t1[idx]
378369
)
379-
rec_speedup_fake_degrad = (
380-
1 if current_correctness == "correct" else b
381-
)
370+
rec_speedup_fake_degrad = 1 if is_tolerated else b
382371
else:
383372
rec_speedup_fake_degrad = (
384373
speedup_at_t1[idx] ** (p + 1)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
5+
python3 -m graph_net.model_path_handler \
6+
--model-path-list "customize_your_model_path_list" \
7+
--handler-config $(base64 -w 0 <<EOF
8+
{
9+
"handler_path": "$GRAPH_NET_ROOT/graph_net/customize_your_sample_pass.py",
10+
"handler_class_name": "customize_your_class_name",
11+
"handler_config": {
12+
"resume": true,
13+
"model_path_prefix": "/customize_your_model_path_prefix",
14+
"output_dir": "/customize_your_output_file"
15+
}
16+
}
17+
EOF
18+
)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
samples/timm/crossvit_small_240.in1k
2+
samples/timm/poolformerv2_s12.sail_in1k
3+
samples/timm/regnety_080.pycls_in1k
4+
samples/timm/dla46x_c.in1k
5+
samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k
6+
samples/timm/efficientnetv2_rw_s.ra2_in1k
7+
samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
8+
samples/timm/fastvit_t8.apple_dist_in1k
9+
samples/timm/test_byobnet.r160_in1k
10+
samples/timm/mambaout_base.in1k
11+
samples/timm/davit_small
12+
samples/timm/resnet61q.ra2_in1k
13+
samples/timm/coat_tiny
14+
samples/timm/regnetx_004.pycls_in1k
15+
samples/timm/convnextv2_large.fcmae
16+
samples/timm/regnety_640.seer
17+
samples/timm/repvit_m1_1.dist_300e_in1k
18+
samples/timm/tinynet_d.in1k
19+
samples/timm/resnetrs270.tf_in1k
20+
samples/timm/cait_m48_448
21+
samples/timm/legacy_seresnet50.in1k
22+
samples/timm/tinynet_a.in1k
23+
samples/timm/convnext_small.fb_in1k
24+
samples/timm/vit_huge_patch14_clip_quickgelu_224.dfn5b
25+
samples/timm/dpn131.mx_in1k
26+
samples/timm/convnextv2_large.fcmae_ft_in1k
27+
samples/timm/convnextv2_small
28+
samples/timm/repvit_m1.dist_in1k
29+
samples/timm/cs3darknet_s
30+
samples/timm/resnet50d.a1_in1k
31+
samples/timm/dm_nfnet_f6
32+
samples/timm/coatnet_1_rw_224
33+
samples/timm/lcnet_050.ra2_in1k
34+
samples/timm/efficientnet_em.ra2_in1k
35+
samples/timm/dpn48b
36+
samples/timm/semnasnet_075.rmsp_in1k
37+
samples/timm/skresnet34.ra_in1k
38+
samples/timm/crossvit_15_dagger_240.in1k
39+
samples/timm/mnasnet_100.rmsp_in1k
40+
samples/timm/mobilenetv3_rw.rmsp_in1k
41+
samples/timm/xception65p.ra3_in1k
42+
samples/timm/coatnet_0_rw_224
43+
samples/timm/eca_nfnet_l3
44+
samples/timm/deit3_base_patch16_224.fb_in1k
45+
samples/timm/mambaout_base_short_rw.sw_e500_in1k
46+
samples/timm/mobilenetv4_conv_small.e1200_r224_in1k
47+
samples/timm/xception71.tf_in1k
48+
samples/timm/dla60.in1k
49+
samples/timm/repghostnet_130.in1k
50+
samples/timm/mambaout_base_plus_rw.sw_e150_in12k
51+
samples/timm/poolformerv2_s36.sail_in1k
52+
samples/timm/deit3_huge_patch14_224.fb_in1k
53+
samples/timm/vit_base_patch32_clip_224.datacompxl
54+
samples/timm/poolformer_m48.sail_in1k
55+
samples/timm/regnety_006.pycls_in1k
56+
samples/timm/starnet_s4.in1k
57+
samples/timm/poolformer_m36.sail_in1k
58+
samples/timm/vit_huge_patch14_gap_224.in1k_ijepa
59+
samples/timm/efficientnet_b3.ra2_in1k
60+
samples/timm/mobilenetv3_large_150d.ra4_e3600_r256_in1k
61+
samples/timm/hgnetv2_b0.ssld_stage1_in22k_in1k
62+
samples/timm/convnextv2_huge.fcmae
63+
samples/timm/davit_huge
64+
samples/timm/regnetx_004_tv.tv2_in1k
65+
samples/timm/dla34.in1k
66+
samples/timm/convnext_xlarge.fb_in22k
67+
samples/timm/resmlp_12_224.fb_dino
68+
samples/timm/fasternet_t1.in1k
69+
samples/timm/resnetblur50.bt_in1k
70+
samples/timm/res2net50d.in1k
71+
samples/timm/vit_base_patch32_224.augreg_in1k
72+
samples/timm/mambaout_base_wide_rw.sw_e500_in1k
73+
samples/timm/vgg19_bn.tv_in1k
74+
samples/timm/vit_small_patch16_rope_ape_224.naver_in1k
75+
samples/timm/hardcorenas_b.miil_green_in1k
76+
samples/timm/vgg16.tv_in1k
77+
samples/timm/xception41p.ra3_in1k
78+
samples/timm/efficientnet_lite0.ra_in1k
79+
samples/timm/regnetv_064.ra3_in1k
80+
samples/timm/regnety_320.pycls_in1k
81+
samples/timm/convnext_pico.d1_in1k
82+
samples/timm/repvit_m1_0.dist_300e_in1k
83+
samples/timm/resnet50c.gluon_in1k
84+
samples/timm/mobileone_s4.apple_in1k
85+
samples/timm/ghostnet_100.in1k
86+
samples/timm/deit_base_distilled_patch16_384
87+
samples/timm/dpn68b.mx_in1k
88+
samples/timm/dla60_res2next
89+
samples/timm/resnet101d.gluon_in1k
90+
samples/timm/eva02_large_patch14_clip_224.merged2b
91+
samples/timm/fasternet_m.in1k
92+
samples/timm/mobilenetv2_110d.ra_in1k
93+
samples/timm/regnetx_064.pycls_in1k
94+
samples/timm/cspresnet50.ra_in1k
95+
samples/timm/resmlp_24_224.fb_dino
96+
samples/timm/mobileone_s3.apple_in1k
97+
samples/timm/mobileone_s2.apple_in1k
98+
samples/timm/res2net101d
99+
samples/timm/hardcorenas_f.miil_green_in1k
100+
samples/timm/hrnet_w18_ssld.paddle_in1k
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#samples/timm/crossvit_small_240.in1k
2+
#samples/timm/poolformerv2_s12.sail_in1k
3+
#samples/timm/regnety_080.pycls_in1k
4+
#samples/timm/dla46x_c.in1k
5+
#samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k
6+
samples/timm/efficientnetv2_rw_s.ra2_in1k
7+
samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
8+
#samples/timm/fastvit_t8.apple_dist_in1k
9+
#samples/timm/test_byobnet.r160_in1k
10+
#samples/timm/mambaout_base.in1k
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from enum import IntEnum
2+
3+
from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
4+
5+
6+
class DefaultErrorEnum(IntEnum):
7+
"""
8+
Values correspond to the minimum tolerance level required.
9+
"""
10+
11+
kAccuracyViolation = 1 # Accuracy
12+
kRuntimeFailure = 2 # Includes Runtime, NaN, Inf, TypeMismatch, etc.
13+
kCompilationFailed = 3 # Compile Failure
14+
15+
@classmethod
16+
def get_error_enum(cls, base_error_type: str) -> "DefaultErrorEnum":
17+
if not base_error_type:
18+
return cls.kRuntimeFailure
19+
20+
etype = base_error_type.lower()
21+
22+
if "accuracy" in etype:
23+
return cls.kAccuracyViolation
24+
25+
if "compile_fail" in etype:
26+
return cls.kCompilationFailed
27+
28+
return cls.kRuntimeFailure
29+
30+
31+
class DefaultPositiveToleranceInterpretation(PositiveToleranceInterpretation):
32+
"""
33+
Legacy interpretation:
34+
- t=1: Accuracy errors tolerated.
35+
- t=3: Runtime/Compilation errors tolerated.
36+
"""
37+
38+
def __init__(self, *argc, **kwargs):
39+
super().__init__(*argc, **kwargs)
40+
41+
def type_name(self) -> str:
42+
return "default"
43+
44+
def get_errno(self, error_type: str) -> int:
45+
return DefaultErrorEnum.get_error_enum(error_type).value
46+
47+
def get_error_type(self, errno: int) -> str:
48+
mapping = {1: "accuracy", 2: "runtime_fail", 3: "compile_fail"}
49+
return mapping.get(errno, "unknown_error")
50+
51+
def get_tolerance_mapping(self) -> dict[int, int]:
52+
return {
53+
DefaultErrorEnum.kAccuracyViolation.value: 1,
54+
DefaultErrorEnum.kRuntimeFailure.value: 3,
55+
DefaultErrorEnum.kCompilationFailed.value: 3,
56+
}
57+
58+
def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
59+
if base_error_code == "correct":
60+
return True
61+
if base_error_code in ["eager_fail", "reference_fail"]:
62+
return False
63+
64+
error_enum = DefaultErrorEnum.get_error_enum(base_error_code)
65+
mapping = self.get_tolerance_mapping()
66+
required_threshold = mapping.get(error_enum.value, 999)
67+
68+
return tolerance >= required_threshold
69+
70+
def num_errno_enum_values(self) -> int:
71+
"""
72+
Default mode defines 3 levels of errors:
73+
1: Accuracy
74+
2: Runtime (Generic)
75+
3: Compilation
76+
"""
77+
return len(DefaultErrorEnum)
Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,61 @@
1-
from pathlib import Path
21
import json
2+
from pathlib import Path
33

44
kDimensionGeneralizationPasses = "dimension_generalization_passes"
5+
kDataTypeGeneralizationPasses = "data_type_generalization_passes"
56
kSymbolicDimensionReifier = "symbolic_dimension_reifier"
67

8+
# Fields for dtype generalization metadata
9+
kDtypeGeneralizationTargetDtype = "dtype_generalization_target_dtype"
10+
kDtypeGeneralizationPrecision = "dtype_generalization_precision"
11+
kDtypeGeneralizationGenerated = "dtype_generalization_generated"
12+
713

814
def read_json(model_path):
15+
"""
16+
Read JSON from graph_net.json file.
17+
18+
Args:
19+
model_path: Path to model directory
20+
21+
Returns:
22+
Dictionary containing JSON data
23+
"""
924
graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
1025
return json.loads(graph_net_json_file_path.read_text())
1126

1227

1328
def update_json(model_path, field, value):
14-
graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
15-
graph_net_json = json.loads(graph_net_json_file_path.read_text())
29+
"""
30+
Update a single field in graph_net.json.
31+
32+
Args:
33+
model_path: Path to model directory or graph_net.json file
34+
field: Field name to update
35+
value: Value to set
36+
"""
37+
if isinstance(model_path, (str, Path)):
38+
model_path = Path(model_path)
39+
# If it's a file path, use it directly; otherwise assume it's a directory
40+
if model_path.suffix == ".json":
41+
graph_net_json_file_path = model_path
42+
else:
43+
graph_net_json_file_path = model_path / "graph_net.json"
44+
else:
45+
graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
46+
47+
# Read existing JSON
48+
if graph_net_json_file_path.exists():
49+
with open(graph_net_json_file_path, "r") as f:
50+
graph_net_json = json.load(f)
51+
else:
52+
graph_net_json = {}
53+
54+
# Update field
1655
graph_net_json[field] = value
17-
graph_net_json_file_path.write_text(json.dumps(graph_net_json, indent=4))
56+
57+
# Atomic write: write to temp file then rename
58+
temp_path = graph_net_json_file_path.with_suffix(".json.tmp")
59+
with open(temp_path, "w") as f:
60+
json.dump(graph_net_json, f, indent=4)
61+
temp_path.replace(graph_net_json_file_path)

0 commit comments

Comments
 (0)