Skip to content

Commit 38cacb0

Browse files
committed
Provide optional bundled io atol/rtol as parameters
Change-Id: I80148fffc4b25e5470b53b559e184ac08630426a
1 parent 69b8da5 commit 38cacb0

File tree

4 files changed

+28
-23
lines changed

4 files changed

+28
-23
lines changed

backends/arm/scripts/run_vkml.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly ins
1919

2020

2121
model=""
22+
opt_flags=""
2223
build_path="cmake-out-vkml"
2324
converter="model-converter"
2425

@@ -33,6 +34,7 @@ help() {
3334
for arg in "$@"; do
3435
case $arg in
3536
-h|--help) help ;;
37+
--optional_flags=*) opt_flags="${arg#*=}";;
3638
--model=*) model="${arg#*=}";;
3739
--build_path=*) build_path="${arg#*=}";;
3840
*)
@@ -59,7 +61,7 @@ runner=$(find ${build_path} -name executor_runner -type f)
5961

6062

6163
echo "--------------------------------------------------------------------------------"
62-
echo "Running ${model} with ${runner}"
64+
echo "Running ${model} with ${runner} ${opt_flags}"
6365
echo "WARNING: The VK_ML layer driver will not provide accurate performance information"
6466
echo "--------------------------------------------------------------------------------"
6567

@@ -75,7 +77,7 @@ fi
7577
log_file=$(mktemp)
7678

7779

78-
${nobuf} ${runner} -model_path ${model} | tee ${log_file}
80+
${nobuf} ${runner} -model_path ${model} ${opt_flags} | tee ${log_file}
7981
echo "[${BASH_SOURCE[0]}] execution complete, $?"
8082

8183
# Most of these can happen for bare metal or linx executor_runner runs.

backends/arm/test/test_arm_baremetal.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ test_models_vkml() { # End to End model tests using model_test.py
254254

255255
# VKML
256256
echo "${TEST_SUITE_NAME}: Test target VKML"
257-
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=resnet18 --extra_flags="-DET_BUNDLE_IO_ATOL=0.2 -DET_BUNDLE_IO_RTOL=0.2"
258-
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=resnet50 --extra_flags="-DET_BUNDLE_IO_ATOL=0.2 -DET_BUNDLE_IO_RTOL=0.2"
257+
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=resnet18 --extra_runtime_flags="--bundleio_atol=0.2 --bundleio_rtol=0.2"
258+
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=resnet50 --extra_runtime_flags="--bundleio_atol=0.2 --bundleio_rtol=0.2"
259259

260260
echo "${TEST_SUITE_NAME}: PASS"
261261
}

backends/arm/test/test_model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,15 @@ def get_args():
6767
parser.add_argument(
6868
"--extra_flags",
6969
required=False,
70-
default=None,
70+
default="",
7171
help="Extra cmake flags to pass the when building the executor_runner",
7272
)
73+
parser.add_argument(
74+
"--extra_runtime_flags",
75+
required=False,
76+
default="",
77+
help="Extra runtime flags to pass the final runner/executable",
78+
)
7379
parser.add_argument(
7480
"--timeout",
7581
required=False,
@@ -228,13 +234,14 @@ def build_vkml_runtime(
228234
return runner
229235

230236

231-
def run_vkml(script_path: str, pte_file: str, runner_build_path: str):
237+
def run_vkml(script_path: str, pte_file: str, runner_build_path: str, extra_flags: str):
232238
run_external_cmd(
233239
[
234240
"bash",
235241
os.path.join(script_path, "run_vkml.sh"),
236242
f"--model={pte_file}",
237243
f"--build_path={runner_build_path}",
244+
f"--optional_flags={extra_flags}",
238245
]
239246
)
240247

@@ -297,7 +304,7 @@ def run_vkml(script_path: str, pte_file: str, runner_build_path: str):
297304
)
298305

299306
start_time = time.perf_counter()
300-
run_vkml(script_path, pte_file, build_path)
307+
run_vkml(script_path, pte_file, build_path, args.extra_runtime_flags)
301308
end_time = time.perf_counter()
302309
print(
303310
f"[Test model: {end_time - start_time:.2f} s] Tested VKML runner: {vkml_runner}"

examples/portable/executor_runner/executor_runner.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ DEFINE_int32(
8383
-1,
8484
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
8585

86+
#ifdef ET_BUNDLE_IO_ENABLED
87+
DEFINE_double(bundleio_rtol, 0.01, "Relative tolerance for bundled IO.");
88+
DEFINE_double(bundleio_atol, 0.01, "Absolute tolerance for bundled IO.");
89+
#endif
90+
8691
using executorch::aten::ScalarType;
8792
using executorch::aten::Tensor;
8893
#ifdef ET_BUNDLE_IO_ENABLED
@@ -108,19 +113,6 @@ using executorch::runtime::Span;
108113
using executorch::runtime::Tag;
109114
using executorch::runtime::TensorInfo;
110115

111-
#ifdef ET_BUNDLE_IO_ENABLED
112-
#if defined(ET_BUNDLE_IO_ATOL)
113-
constexpr float bundleio_atol = ET_BUNDLE_IO_ATOL;
114-
#else
115-
constexpr float bundleio_atol = 0.01;
116-
#endif
117-
#if defined(ET_BUNDLE_IO_RTOL)
118-
constexpr float bundleio_rtol = ET_BUNDLE_IO_RTOL;
119-
#else
120-
constexpr float bundleio_rtol = 0.01;
121-
#endif
122-
#endif
123-
124116
/// Helper to manage resources for ETDump generation
125117
class EventTraceManager {
126118
public:
@@ -604,7 +596,11 @@ int main(int argc, char** argv) {
604596
}
605597

606598
Error status = verify_method_outputs(
607-
*method, model_pte, testset_idx, bundleio_rtol, bundleio_atol);
599+
*method,
600+
model_pte,
601+
testset_idx,
602+
FLAGS_bundleio_rtol,
603+
FLAGS_bundleio_atol);
608604
if (status == Error::Ok) {
609605
ET_LOG(Info, "Model output match expected BundleIO bpte ref data.");
610606
ET_LOG(Info, "TEST: BundleIO index[%zu] Test_result: PASS", testset_idx);
@@ -613,8 +609,8 @@ int main(int argc, char** argv) {
613609
ET_LOG(
614610
Error,
615611
"Model output don't match expected BundleIO bpte ref data. rtol=%f atol=%f",
616-
bundleio_rtol,
617-
bundleio_atol);
612+
FLAGS_bundleio_rtol,
613+
FLAGS_bundleio_atol);
618614
ET_LOG(Error, "TEST: BundleIO index[%zu] Test_result: FAIL", testset_idx);
619615
ET_LOG(
620616
Error,

0 commit comments

Comments
 (0)