|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from inputgen.argument.engine import MetaArgEngine |
| 8 | +from inputgen.argument.gen import ArgumentGenerator |
| 9 | +from inputgen.attribute.model import Attribute |
| 10 | +from inputgen.specs.model import Spec |
| 11 | + |
| 12 | + |
| 13 | +def reverse_topological_sort(graph): |
| 14 | + def dfs(node, visited, strack): |
| 15 | + visited[node] = True |
| 16 | + for neig in graph[node]: |
| 17 | + if not visited[neig]: |
| 18 | + dfs(neig, visited, strack) |
| 19 | + stack.append(node) |
| 20 | + |
| 21 | + visited = {node: False for node in graph} |
| 22 | + stack = [] |
| 23 | + |
| 24 | + for node in graph: |
| 25 | + if not visited[node]: |
| 26 | + dfs(node, visited, stack) |
| 27 | + |
| 28 | + return stack |
| 29 | + |
| 30 | + |
| 31 | +def inverse_permutation(permutation): |
| 32 | + n = len(permutation) |
| 33 | + inverse = [0] * n |
| 34 | + for i in range(n): |
| 35 | + inverse[permutation[i]] = i |
| 36 | + return inverse |
| 37 | + |
| 38 | + |
| 39 | +class MetaArgTupleEngine: |
| 40 | + def __init__(self, spec: Spec, out: bool = False): |
| 41 | + if out: |
| 42 | + raise NotImplementedError("out=True is not supported yet") |
| 43 | + self.args = spec.inspec |
| 44 | + self.order = self._sort_dependencies() |
| 45 | + self.order_inverse_perm = inverse_permutation(self.order) |
| 46 | + |
| 47 | + def _generate_dependency_dag(self): |
| 48 | + graph = {} |
| 49 | + for i, arg in enumerate(self.args): |
| 50 | + if arg.deps is None: |
| 51 | + graph[i] = [] |
| 52 | + else: |
| 53 | + graph[i] = arg.deps |
| 54 | + return graph |
| 55 | + |
| 56 | + def _sort_dependencies(self): |
| 57 | + graph = self._generate_dependency_dag() |
| 58 | + return reverse_topological_sort(graph) |
| 59 | + |
| 60 | + def _sort_meta_tuple(self, meta_tuple): |
| 61 | + return tuple( |
| 62 | + meta_tuple[self.order_inverse_perm[i]] for i in range(len(self.args)) |
| 63 | + ) |
| 64 | + |
| 65 | + def _get_deps(self, meta_tuple, arg_deps): |
| 66 | + value_tuple = tuple(ArgumentGenerator(m).gen() for m in meta_tuple) |
| 67 | + return tuple(value_tuple[self.order_inverse_perm[ix]] for ix in arg_deps) |
| 68 | + |
| 69 | + def gen_meta_tuples(self, valid: bool, focus_ix: int): |
| 70 | + tuples = [()] |
| 71 | + for ix in self.order: |
| 72 | + arg = self.args[ix] |
| 73 | + new_tuples = [] |
| 74 | + focuses = [None] |
| 75 | + if ix == focus_ix: |
| 76 | + focuses = Attribute.hierarchy(arg.type) |
| 77 | + for focus in focuses: |
| 78 | + for meta_tuple in tuples: |
| 79 | + deps = self._get_deps(meta_tuple, arg.deps) |
| 80 | + engine = MetaArgEngine(arg.type, arg.constraints, deps, valid) |
| 81 | + for meta_arg in engine.gen(focus): |
| 82 | + new_tuples.append(meta_tuple + (meta_arg,)) |
| 83 | + tuples = new_tuples |
| 84 | + return map(self._sort_meta_tuple, tuples) |
| 85 | + |
| 86 | + def gen_valid_meta_tuples(self): |
| 87 | + valid_tuples = [] |
| 88 | + for ix in range(len(self.args)): |
| 89 | + valid_tuples += self.gen_meta_tuples(True, ix) |
| 90 | + return valid_tuples |
| 91 | + |
| 92 | + def gen_invalid_from_valid(self, valid_tuple): |
| 93 | + # Valid [str(x) for x in valid_tuple] |
| 94 | + valid_value_tuple = tuple(ArgumentGenerator(m).gen() for m in valid_tuple) |
| 95 | + invalid_tuples = [] |
| 96 | + for ix in range(len(self.args)): |
| 97 | + arg = self.args[ix] |
| 98 | + # Generating invalid argument {ix} {arg.type} |
| 99 | + deps = tuple(valid_value_tuple[i] for i in arg.deps) |
| 100 | + for focus in Attribute.hierarchy(arg.type): |
| 101 | + engine = MetaArgEngine(arg.type, arg.constraints, deps, False) |
| 102 | + for meta_arg in engine.gen(focus): |
| 103 | + invalid_tuple = ( |
| 104 | + valid_tuple[:ix] + (meta_arg,) + valid_tuple[ix + 1 :] |
| 105 | + ) |
| 106 | + # Invalid {ix} {focus} [str(x) for x in invalid_tuple] |
| 107 | + invalid_tuples.append(invalid_tuple) |
| 108 | + invalid_tuples = list(set(invalid_tuples)) |
| 109 | + return invalid_tuples |
| 110 | + |
| 111 | + def gen_invalid_meta_tuples(self): |
| 112 | + valid_tuples = self.gen_valid_meta_tuples() |
| 113 | + invalid_tuples = [] |
| 114 | + for valid_tuple in valid_tuples: |
| 115 | + invalids = self.gen_invalid_from_valid(valid_tuple) |
| 116 | + invalid_tuples += invalids |
| 117 | + invalid_tuples = list(set(invalid_tuples)) |
| 118 | + return invalid_tuples |
| 119 | + |
| 120 | + def gen(self, valid: bool = True): |
| 121 | + if valid: |
| 122 | + return self.gen_valid_meta_tuples() |
| 123 | + else: |
| 124 | + return self.gen_invalid_meta_tuples() |
0 commit comments