Skip to content

Commit 6d47586

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Introduce OpModelGenerator (#42)
Summary: Pull Request resolved: #42 imported-using-ghimport Test Plan: Imported from OSS Rollback Plan: Reviewed By: digantdesai Differential Revision: D80468314 Pulled By: manuelcandales fbshipit-source-id: c0713bec905645b832f2a73b9fe022ee1d0330dc
1 parent da927e1 commit 6d47586

File tree

9 files changed

+791
-6
lines changed

9 files changed

+791
-6
lines changed

examples/minimal_modelgen_example.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from facto.inputgen.utils.config import TensorConfig
8+
from facto.modelgen.gen import OpModelGenerator
9+
from facto.specdb.db import SpecDictDB
10+
from facto.utils.ops import get_op_overload
11+
12+
13+
def main():
14+
op_name = "add.Tensor"
15+
spec = SpecDictDB[op_name]
16+
op = get_op_overload(op_name)
17+
config = TensorConfig(device="cpu", half_precision=False)
18+
for model, args, kwargs in OpModelGenerator(op, spec, config).gen(verbose=True):
19+
model(*args, **kwargs)
20+
21+
22+
if __name__ == "__main__":
23+
main()

facto/inputgen/argtuple/gen.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ def _apply_constraints_to_arg(self, arg, config: TensorConfig):
4848
# Create a copy of the argument with potentially modified constraints
4949
modified_arg = deepcopy(arg)
5050

51+
# Add rank constraints for tensor arguments when zerodim tensors are not allowed
52+
if not config.is_allowed(Condition.ALLOW_ZERODIM):
53+
if arg.type.is_tensor():
54+
rank_constraint = cp.Rank.Ge(lambda deps: 1)
55+
modified_arg.constraints = modified_arg.constraints + [rank_constraint]
56+
elif arg.type.is_tensor_list():
57+
rank_constraint = cp.Rank.Ge(lambda deps, length, ix: 1)
58+
modified_arg.constraints = modified_arg.constraints + [rank_constraint]
59+
5160
# Add size constraints for tensor arguments when empty tensors are not allowed
5261
if not config.is_allowed(Condition.ALLOW_EMPTY):
5362
if arg.type.is_tensor() or arg.type.is_tensor_list():
@@ -91,12 +100,14 @@ def gen_tuple(
91100
return posargs, inkwargs, outargs
92101

93102
def gen(
94-
self, *, valid: bool = True, out: bool = False
103+
self, *, valid: bool = True, out: bool = False, verbose: bool = False
95104
) -> Generator[
96105
Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
97106
]:
98107
engine = MetaArgTupleEngine(self._modified_spec, out=out)
99108
for meta_tuple in engine.gen(valid=valid):
109+
if verbose:
110+
print(f"Generated meta_tuple: {[str(x) for x in meta_tuple]}")
100111
yield self.gen_tuple(meta_tuple, out=out)
101112

102113
def gen_errors(

facto/inputgen/argument/gen.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,27 +246,27 @@ def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor:
246246
)
247247

248248
t = torch.randint(
249-
low=low, high=high, size=size, dtype=dtype, generator=torch_rng
249+
low=low, high=high, size=size, dtype=torch.float, generator=torch_rng
250250
)
251251
if not self.space.contains(0):
252252
if high > 0:
253253
pos = torch.randint(
254254
low=max(1, low),
255255
high=high,
256256
size=size,
257-
dtype=dtype,
257+
dtype=torch.float,
258258
generator=torch_rng,
259259
)
260260
else:
261261
pos = torch.randint(
262-
low=low, high=0, size=size, dtype=dtype, generator=torch_rng
262+
low=low, high=0, size=size, dtype=torch.float, generator=torch_rng
263263
)
264264
t = torch.where(t == 0, pos, t)
265265

266266
if dtype in dt._int:
267-
return t
267+
return t.to(dtype)
268268
if dtype in dt._floating:
269-
return t / FLOAT_RESOLUTION
269+
return (t / FLOAT_RESOLUTION).to(dtype)
270270
raise ValueError(f"Unsupported Dtype: {dtype}")
271271

272272

facto/inputgen/utils/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
class Condition(str, Enum):
13+
ALLOW_ZERODIM = "zerodim"
1314
ALLOW_EMPTY = "empty"
1415
ALLOW_TRANSPOSED = "transposed"
1516
ALLOW_PERMUTED = "permuted"
@@ -23,6 +24,7 @@ def __init__(self, device="cpu", disallow_dtypes=None, **conditions):
2324
self.device = device
2425
self.disallow_dtypes = disallow_dtypes or []
2526
self.conditions = {condition: False for condition in Condition}
27+
self.conditions[Condition.ALLOW_ZERODIM] = True # allow zerodim by default
2628
for condition, value in conditions.items():
2729
if condition in self.conditions:
2830
self.conditions[condition] = value

facto/modelgen/__init__.py

Whitespace-only changes.

facto/modelgen/gen.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Dict, Generator, List, Optional, Tuple
8+
9+
import torch.nn as nn
10+
11+
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
12+
from facto.inputgen.specs.model import Spec
13+
from facto.inputgen.utils.config import TensorConfig
14+
15+
16+
def separate_forward_and_model_inputs(
17+
spec: Spec, args: List[Any], kwargs: Dict[str, Any]
18+
) -> Tuple[List[Any], Dict[str, Any], List[Any], Dict[str, Any]]:
19+
"""
20+
Separate forward inputs from model parameters using FACTO's ArgType system.
21+
22+
Args:
23+
spec: The operation specification containing argument type information
24+
args: All positional arguments
25+
kwargs: All keyword arguments
26+
27+
Returns:
28+
Tuple of (forward_args, forward_kwargs, model_args, model_kwargs)
29+
"""
30+
forward_args = []
31+
model_args = []
32+
33+
forward_kwargs = {}
34+
model_kwargs = {}
35+
36+
for i, inarg in enumerate(spec.inspec):
37+
if inarg.kw:
38+
if inarg.type.is_tensor() or inarg.type.is_tensor_list():
39+
forward_kwargs[inarg.name] = kwargs[inarg.name]
40+
else:
41+
model_kwargs[inarg.name] = kwargs[inarg.name]
42+
else:
43+
if inarg.type.is_tensor() or inarg.type.is_tensor_list():
44+
forward_args.append(args[i])
45+
else:
46+
model_args.append(args[i])
47+
48+
return forward_args, forward_kwargs, model_args, model_kwargs
49+
50+
51+
def combine_forward_and_model_inputs(
52+
spec: Spec,
53+
forward_args: Tuple[Any],
54+
forward_kwargs: Dict[str, Any],
55+
model_args: Tuple[Any],
56+
model_kwargs: Dict[str, Any],
57+
) -> Tuple[List[Any], Dict[str, Any]]:
58+
"""
59+
Combine forward inputs with model parameters using FACTO's ArgType system.
60+
61+
Args:
62+
spec: The operation specification containing argument type information
63+
args: All positional arguments
64+
kwargs: All keyword arguments
65+
model_args: All model parameters
66+
model_kwargs: All model keyword parameters
67+
68+
Returns:
69+
Tuple of (args, kwargs)
70+
"""
71+
combined_args = []
72+
combined_kwargs = {}
73+
74+
forward_args_ix = 0
75+
model_args_ix = 0
76+
77+
# Iterate over the input specification
78+
for ix, inarg in enumerate(spec.inspec):
79+
if inarg.kw:
80+
# If the argument is a keyword argument, check if it's a tensor or tensor list
81+
if inarg.type.is_tensor() or inarg.type.is_tensor_list():
82+
combined_kwargs[inarg.name] = forward_kwargs[inarg.name]
83+
else:
84+
combined_kwargs[inarg.name] = model_kwargs[inarg.name]
85+
else:
86+
# If the argument is a positional argument, check if it's a tensor or tensor list
87+
if inarg.type.is_tensor() or inarg.type.is_tensor_list():
88+
combined_args.append(forward_args[forward_args_ix])
89+
forward_args_ix += 1
90+
else:
91+
combined_args.append(model_args[model_args_ix])
92+
model_args_ix += 1
93+
94+
return combined_args, combined_kwargs
95+
96+
97+
class OpModel(nn.Module):
98+
"""
99+
A PyTorch model that wraps a torch aten operation.
100+
101+
This class creates a simple model that applies a given torch operation
102+
to its inputs in the forward pass.
103+
"""
104+
105+
def __init__(
106+
self, op: Any, spec: Spec, op_name: str = "", *model_args, **model_kwargs
107+
):
108+
"""
109+
Initialize the OpModel.
110+
111+
Args:
112+
op: The torch aten operation to wrap
113+
op_name: Optional name for the operation (for debugging/logging)
114+
*model_args: Positional model parameters
115+
**model_kwargs: Keyword model parameters
116+
"""
117+
super().__init__()
118+
self.op = op
119+
self.op_name = op_name or str(op)
120+
self.spec = spec
121+
self.model_args = model_args
122+
self.model_kwargs = model_kwargs
123+
124+
def forward(self, *args, **kwargs) -> Any:
125+
"""
126+
Forward pass that applies the wrapped operation to the inputs.
127+
128+
Args:
129+
*args: Positional arguments to pass to the operation
130+
**kwargs: Keyword arguments to pass to the operation
131+
132+
Returns:
133+
The result of applying the operation to the inputs
134+
"""
135+
op_args, op_kwargs = combine_forward_and_model_inputs(
136+
self.spec, args, kwargs, self.model_args, self.model_kwargs
137+
)
138+
return self.op(*op_args, **op_kwargs)
139+
140+
def __repr__(self) -> str:
141+
return f"OpModel(op={self.op_name})"
142+
143+
144+
class OpModelGenerator:
145+
"""
146+
Generator that creates OpModel instances with appropriate inputs for testing.
147+
148+
This class takes a torch operation and its specification, then uses
149+
ArgumentTupleGenerator to create OpModel instances along with valid
150+
inputs for the forward function. It automatically separates tensor inputs
151+
from non-tensor parameters for ExecuTorch compatibility using FACTO's ArgType system.
152+
"""
153+
154+
def __init__(self, op: Any, spec: Spec, config: Optional[TensorConfig] = None):
155+
"""
156+
Initialize the OpModelGenerator.
157+
158+
Args:
159+
op: The torch aten operation to wrap in models
160+
spec: The specification for the operation's arguments
161+
config: Optional tensor configuration for input generation
162+
"""
163+
self.op = op
164+
self.spec = spec
165+
self.config = config
166+
self.arg_generator = ArgumentTupleGenerator(spec, config)
167+
168+
def gen(
169+
self,
170+
*,
171+
valid: bool = True,
172+
verbose: bool = False,
173+
max_count: Optional[int] = None,
174+
) -> Generator[Tuple[OpModel, List[Any], Dict[str, Any]], None, None]:
175+
"""
176+
Generate OpModel instances with corresponding inputs.
177+
178+
Args:
179+
valid: Whether to generate valid inputs (default: True)
180+
max_count: Maximum number of models to generate (default: None for unlimited)
181+
182+
Yields:
183+
Tuple containing:
184+
- OpModel instance wrapping the operation
185+
- List of positional arguments for forward()
186+
- Dict of keyword arguments for forward()
187+
"""
188+
count = 0
189+
for args, kwargs, _ in self.arg_generator.gen(
190+
valid=valid, out=False, verbose=verbose
191+
):
192+
if max_count is not None and count >= max_count:
193+
break
194+
195+
# Separate tensor inputs from non-tensor parameters
196+
forward_args, forward_kwargs, model_args, model_kwargs = (
197+
separate_forward_and_model_inputs(self.spec, args, kwargs)
198+
)
199+
200+
# Create model instance
201+
model = OpModel(
202+
self.op, self.spec, self.spec.op, *model_args, **model_kwargs
203+
)
204+
205+
yield model, forward_args, forward_kwargs
206+
count += 1
207+
208+
def test_model_with_inputs(
209+
self, model: OpModel, args: List[Any], kwargs: Dict[str, Any]
210+
) -> Tuple[bool, Optional[Any], Optional[Exception]]:
211+
"""
212+
Test a model with given inputs and return the result.
213+
214+
Args:
215+
model: The OpModel to test
216+
args: Positional arguments for the model
217+
kwargs: Keyword arguments for the model
218+
219+
Returns:
220+
Tuple containing:
221+
- Boolean indicating success/failure
222+
- The output if successful, None if failed
223+
- The exception if failed, None if successful
224+
"""
225+
try:
226+
output = model(*args, **kwargs)
227+
return True, output, None
228+
except Exception as e:
229+
return False, None, e
230+
231+
def __repr__(self) -> str:
232+
return f"OpModelGenerator(op={self.spec.op})"

0 commit comments

Comments
 (0)