1818ALL_EDGE_OPS = SAMPLE_INPUT .keys () | CUSTOM_EDGE_OPS
1919
2020# Add all targets and TOSA profiles we support here.
21- TARGETS = { "tosa_BI " , "tosa_MI " , "u55_BI" , "u85_BI" }
21+ TARGETS = [ "tosa_MI " , "tosa_BI " , "u55_BI" , "u85_BI" ]
2222
2323
24- def get_edge_ops ():
24+ def get_op_name_map ():
2525 """
26- Returns a set with edge_ops with names on the form to be used in unittests:
26+ Returns a mapping from names on the form to be used in unittests to edge op :
2727 1. Names are in lowercase.
28- 2. Overload is ignored if it is 'default', otherwise its appended with an underscore.
28+ 2. Overload is ignored if 'default', otherwise it's appended with an underscore.
2929 3. Overly verbose name are shortened by removing certain prefixes/suffixes.
3030
3131 Examples:
3232 abs.default -> abs
3333 split_copy.Tensor -> split_tensor
3434 """
35- edge_ops = set ()
35+ map = {}
3636 for edge_name in ALL_EDGE_OPS :
3737 op , overload = edge_name .split ("." )
3838
@@ -45,21 +45,24 @@ def get_edge_ops():
4545 overload = overload .lower ()
4646
4747 if overload == "default" :
48- edge_ops . add ( op )
48+ map [ op ] = edge_name
4949 else :
50- edge_ops . add ( f"{ op } _{ overload } " )
50+ map [ f"{ op } _{ overload } " ] = edge_name
5151
52- return edge_ops
52+ return map
5353
5454
55- def parse_test_name (test_name : str , edge_ops : set [str ]) -> tuple [str , str , bool ]:
55+ def parse_test_name (
56+ test_name : str , op_name_map : dict [str , str ]
57+ ) -> tuple [str , str , bool ]:
5658 """
5759 Parses a test name on the form
5860 test_OP_TARGET_<not_delegated>_<any_other_info>
59- where OP must match a string in edge_ops and TARGET must match one string in TARGETS.
60- The "not_delegated" suffix indicates that the test tests that the op is not delegated.
61+ where OP must match a key in op_name_map and TARGET one string in TARGETS. The
62+ "not_delegated" suffix indicates that the test tests that the op is not delegated.
6163
62- Examples of valid names: "test_mm_u55_BI_not_delegated" or "test_add_scalar_tosa_MI_two_inputs".
64+ Examples of valid names: "test_mm_u55_BI_not_delegated" and
65+ "test_add_scalar_tosa_MI_two_inputs".
6366
6467 Returns a tuple (OP, TARGET, IS_DELEGATED) if valid.
6568 """
@@ -83,7 +86,7 @@ def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]
8386
8487 assert target != "None" , f"{ test_name } does not contain one of { TARGETS } "
8588 assert (
86- op in edge_ops
89+ op in op_name_map . keys ()
8790 ), f"Parsed unvalid OP from { test_name } , { op } does not exist in edge.yaml or CUSTOM_EDGE_OPS"
8891
8992 return op , target , is_delegated
@@ -95,13 +98,13 @@ def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]
9598
9699 sys .tracebacklimit = 0 # Do not print stack trace
97100
98- edge_ops = get_edge_ops ()
101+ op_name_map = get_op_name_map ()
99102 exit_code = 0
100103
101104 for test_name in sys .argv [1 :]:
102105 try :
103106 assert test_name [:5 ] == "test_" , f"Unexpected input: { test_name } "
104- parse_test_name (test_name , edge_ops )
107+ parse_test_name (test_name , op_name_map )
105108 except AssertionError as e :
106109 print (e )
107110 exit_code = 1
0 commit comments