Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
return False


def maybe_get_tosa_collate_path() -> str | None:
"""
Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the
path to the where to store the current tests if it is set.
"""
tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH")
if tosa_test_base:
current_test = os.environ.get("PYTEST_CURRENT_TEST")
#'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)'
test_class = current_test.split("::")[1]
test_name = current_test.split("::")[-1].split(" ")[0]
if "BI" in test_name:
tosa_test_base = os.path.join(tosa_test_base, "tosa-bi")
elif "MI" in test_name:
tosa_test_base = os.path.join(tosa_test_base, "tosa-mi")
else:
tosa_test_base = os.path.join(tosa_test_base, "other")

return os.path.join(tosa_test_base, test_class, test_name)

return None


def get_tosa_compile_spec(
permute_memory_to_nhwc=True, custom_path=None
) -> list[CompileSpec]:
Expand All @@ -101,7 +124,13 @@ def get_tosa_compile_spec_unbuilt(
"""Get the ArmCompileSpecBuilder for the default TOSA tests, to modify
the compile spec before calling .build() to finalize it.
"""
intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_")
if not custom_path:
intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
prefix="arm_tosa_"
)
else:
intermediate_path = custom_path

if not os.path.exists(intermediate_path):
os.makedirs(intermediate_path, exist_ok=True)
compile_spec_builder = (
Expand Down
37 changes: 37 additions & 0 deletions backends/arm/test/misc/test_debug_feats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
import os
import shutil
import tempfile
import unittest

Expand Down Expand Up @@ -149,3 +150,39 @@ def test_dump_ops_and_dtypes(self):
.dump_operator_distribution()
)
# Just test that there are no execeptions.


class TestCollateTosaTests(unittest.TestCase):
"""Tests the collation of TOSA tests through setting the environment variable TOSA_TESTCASE_BASE_PATH."""

def test_collate_tosa_BI_tests(self):
# Set the environment variable to trigger the collation of TOSA tests
os.environ["TOSA_TESTCASES_BASE_PATH"] = "test_collate_tosa_tests"
# Clear out the directory

model = Linear(20, 30)
(
ArmTester(
model,
example_inputs=model.get_inputs(),
compile_spec=common.get_tosa_compile_spec(),
)
.quantize()
.export()
.to_edge()
.partition()
.to_executorch()
)
# test that the output directory is created and contains the expected files
assert os.path.exists(
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
)
assert os.path.exists(
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag8.tosa"
)
assert os.path.exists(
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag8.json"
)

os.environ.pop("TOSA_TESTCASES_BASE_PATH")
shutil.rmtree("test_collate_tosa_tests", ignore_errors=True)
Loading