Introduce OpModelGenerator #42

Open
wants to merge 3 commits into base: gh/manuelcandales/24/base
23 changes: 23 additions & 0 deletions examples/minimal_modelgen_example.py
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from facto.inputgen.utils.config import TensorConfig
from facto.modelgen.gen import OpModelGenerator
from facto.specdb.db import SpecDictDB
from facto.utils.ops import get_op_overload


def main():
    op_name = "add.Tensor"
    spec = SpecDictDB[op_name]
    op = get_op_overload(op_name)
    config = TensorConfig(device="cpu", half_precision=False)
    for model, args, kwargs in OpModelGenerator(op, spec, config).gen(verbose=True):
        model(*args, **kwargs)


if __name__ == "__main__":
    main()
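A bounded variant of the example above, shown as a sketch rather than as part of this diff; it relies only on APIs introduced here (gen's max_count parameter and test_model_with_inputs):

import torch

from facto.inputgen.utils.config import TensorConfig
from facto.modelgen.gen import OpModelGenerator
from facto.specdb.db import SpecDictDB
from facto.utils.ops import get_op_overload

op_name = "add.Tensor"
generator = OpModelGenerator(
    get_op_overload(op_name), SpecDictDB[op_name], TensorConfig(device="cpu")
)
# Stop after 10 models and report runtime failures instead of raising.
for model, args, kwargs in generator.gen(max_count=10):
    ok, output, error = generator.test_model_with_inputs(model, args, kwargs)
    if not ok:
        print(f"{model} failed with: {error}")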
13 changes: 12 additions & 1 deletion facto/inputgen/argtuple/gen.py
@@ -48,6 +48,15 @@ def _apply_constraints_to_arg(self, arg, config: TensorConfig):
        # Create a copy of the argument with potentially modified constraints
        modified_arg = deepcopy(arg)

+        # Add rank constraints for tensor arguments when zerodim tensors are not allowed
+        if not config.is_allowed(Condition.ALLOW_ZERODIM):
+            if arg.type.is_tensor():
+                rank_constraint = cp.Rank.Ge(lambda deps: 1)
+                modified_arg.constraints = modified_arg.constraints + [rank_constraint]
+            elif arg.type.is_tensor_list():
+                rank_constraint = cp.Rank.Ge(lambda deps, length, ix: 1)
+                modified_arg.constraints = modified_arg.constraints + [rank_constraint]
+
        # Add size constraints for tensor arguments when empty tensors are not allowed
        if not config.is_allowed(Condition.ALLOW_EMPTY):
            if arg.type.is_tensor() or arg.type.is_tensor_list():
@@ -91,12 +100,14 @@ def gen_tuple(
        return posargs, inkwargs, outargs

    def gen(
-        self, *, valid: bool = True, out: bool = False
+        self, *, valid: bool = True, out: bool = False, verbose: bool = False
    ) -> Generator[
        Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
    ]:
        engine = MetaArgTupleEngine(self._modified_spec, out=out)
        for meta_tuple in engine.gen(valid=valid):
+            if verbose:
+                print(f"Generated meta_tuple: {[str(x) for x in meta_tuple]}")
            yield self.gen_tuple(meta_tuple, out=out)

    def gen_errors(
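From the caller's side, the two changes in this file combine as below. A sketch, assuming TensorConfig accepts the Condition string values as keyword names (as the half_precision keyword in the example file suggests):

from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.utils.config import TensorConfig
from facto.specdb.db import SpecDictDB

# zerodim=False makes every tensor argument pick up a Rank >= 1 constraint,
# and verbose=True prints each meta tuple as it is generated.
config = TensorConfig(device="cpu", zerodim=False)
generator = ArgumentTupleGenerator(SpecDictDB["add.Tensor"], config)
for posargs, inkwargs, outargs in generator.gen(valid=True, verbose=True):
    pass  # consume the generated argument tuples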
10 changes: 5 additions & 5 deletions facto/inputgen/argument/gen.py
@@ -246,27 +246,27 @@ def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor:
            )

        t = torch.randint(
-            low=low, high=high, size=size, dtype=dtype, generator=torch_rng
+            low=low, high=high, size=size, dtype=torch.float, generator=torch_rng
        )
        if not self.space.contains(0):
            if high > 0:
                pos = torch.randint(
                    low=max(1, low),
                    high=high,
                    size=size,
-                    dtype=dtype,
+                    dtype=torch.float,
                    generator=torch_rng,
                )
            else:
                pos = torch.randint(
-                    low=low, high=0, size=size, dtype=dtype, generator=torch_rng
+                    low=low, high=0, size=size, dtype=torch.float, generator=torch_rng
                )
            t = torch.where(t == 0, pos, t)

        if dtype in dt._int:
-            return t
+            return t.to(dtype)
        if dtype in dt._floating:
-            return t / FLOAT_RESOLUTION
+            return (t / FLOAT_RESOLUTION).to(dtype)
        raise ValueError(f"Unsupported Dtype: {dtype}")
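The pattern behind these edits, in isolation: values are now always drawn as torch.float and the requested dtype is applied in a single cast at the end, so torch.randint no longer has to fill the target dtype directly. A sketch; the FLOAT_RESOLUTION value below is a placeholder, the real constant is defined in this module:

import torch

FLOAT_RESOLUTION = 8  # placeholder value for illustration only

# Draw integer values in float, then cast once to the requested dtype.
t = torch.randint(low=-16, high=16, size=(4,), dtype=torch.float)
int_result = t.to(torch.int32)                       # integer target dtype
half_result = (t / FLOAT_RESOLUTION).to(torch.half)  # floating target dtype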
2 changes: 2 additions & 0 deletions facto/inputgen/utils/config.py
@@ -10,6 +10,7 @@


class Condition(str, Enum):
+    ALLOW_ZERODIM = "zerodim"
    ALLOW_EMPTY = "empty"
    ALLOW_TRANSPOSED = "transposed"
    ALLOW_PERMUTED = "permuted"
@@ -23,6 +24,7 @@ def __init__(self, device="cpu", disallow_dtypes=None, **conditions):
        self.device = device
        self.disallow_dtypes = disallow_dtypes or []
        self.conditions = {condition: False for condition in Condition}
+        self.conditions[Condition.ALLOW_ZERODIM] = True  # allow zerodim by default
        for condition, value in conditions.items():
            if condition in self.conditions:
                self.conditions[condition] = value
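Usage of the new condition, as a sketch; it assumes the keyword name matches the enum's string value, which is how half_precision is passed in the example file:

from facto.inputgen.utils.config import Condition, TensorConfig

config = TensorConfig(device="cpu", zerodim=False)  # disallow 0-dim tensors
assert not config.is_allowed(Condition.ALLOW_ZERODIM)

default = TensorConfig(device="cpu")  # zerodim remains allowed by default
assert default.is_allowed(Condition.ALLOW_ZERODIM)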
Empty file added facto/modelgen/__init__.py
232 changes: 232 additions & 0 deletions facto/modelgen/gen.py
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Generator, List, Optional, Tuple

import torch.nn as nn

from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.specs.model import Spec
from facto.inputgen.utils.config import TensorConfig


def separate_forward_and_model_inputs(
    spec: Spec, args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any], List[Any], Dict[str, Any]]:
    """
    Separate forward inputs from model parameters using FACTO's ArgType system.

    Args:
        spec: The operation specification containing argument type information
        args: All positional arguments
        kwargs: All keyword arguments

    Returns:
        Tuple of (forward_args, forward_kwargs, model_args, model_kwargs)
    """
    forward_args = []
    model_args = []

    forward_kwargs = {}
    model_kwargs = {}

    for i, inarg in enumerate(spec.inspec):
        if inarg.kw:
            if inarg.type.is_tensor() or inarg.type.is_tensor_list():
                forward_kwargs[inarg.name] = kwargs[inarg.name]
            else:
                model_kwargs[inarg.name] = kwargs[inarg.name]
        else:
            if inarg.type.is_tensor() or inarg.type.is_tensor_list():
                forward_args.append(args[i])
            else:
                model_args.append(args[i])

    return forward_args, forward_kwargs, model_args, model_kwargs


def combine_forward_and_model_inputs(
    spec: Spec,
    forward_args: Tuple[Any, ...],
    forward_kwargs: Dict[str, Any],
    model_args: Tuple[Any, ...],
    model_kwargs: Dict[str, Any],
) -> Tuple[List[Any], Dict[str, Any]]:
    """
    Combine forward inputs with model parameters using FACTO's ArgType system.

    Args:
        spec: The operation specification containing argument type information
        forward_args: Positional tensor inputs destined for forward()
        forward_kwargs: Keyword tensor inputs destined for forward()
        model_args: All positional model parameters
        model_kwargs: All keyword model parameters

    Returns:
        Tuple of (args, kwargs)
    """
    combined_args = []
    combined_kwargs = {}

    forward_args_ix = 0
    model_args_ix = 0

    # Iterate over the input specification
    for ix, inarg in enumerate(spec.inspec):
        if inarg.kw:
            # If the argument is a keyword argument, check if it's a tensor or tensor list
            if inarg.type.is_tensor() or inarg.type.is_tensor_list():
                combined_kwargs[inarg.name] = forward_kwargs[inarg.name]
            else:
                combined_kwargs[inarg.name] = model_kwargs[inarg.name]
        else:
            # If the argument is a positional argument, check if it's a tensor or tensor list
            if inarg.type.is_tensor() or inarg.type.is_tensor_list():
                combined_args.append(forward_args[forward_args_ix])
                forward_args_ix += 1
            else:
                combined_args.append(model_args[model_args_ix])
                model_args_ix += 1

    return combined_args, combined_kwargs


class OpModel(nn.Module):
    """
    A PyTorch model that wraps a torch aten operation.

    This class creates a simple model that applies a given torch operation
    to its inputs in the forward pass.
    """

    def __init__(
        self, op: Any, spec: Spec, op_name: str = "", *model_args, **model_kwargs
    ):
        """
        Initialize the OpModel.

        Args:
            op: The torch aten operation to wrap
            spec: The specification for the operation's arguments
            op_name: Optional name for the operation (for debugging/logging)
            *model_args: Positional model parameters
            **model_kwargs: Keyword model parameters
        """
        super().__init__()
        self.op = op
        self.op_name = op_name or str(op)
        self.spec = spec
        self.model_args = model_args
        self.model_kwargs = model_kwargs

    def forward(self, *args, **kwargs) -> Any:
        """
        Forward pass that applies the wrapped operation to the inputs.

        Args:
            *args: Positional tensor inputs; stored model parameters are merged back in
            **kwargs: Keyword tensor inputs

        Returns:
            The result of applying the operation to the combined inputs
        """
        op_args, op_kwargs = combine_forward_and_model_inputs(
            self.spec, args, kwargs, self.model_args, self.model_kwargs
        )
        return self.op(*op_args, **op_kwargs)

    def __repr__(self) -> str:
        return f"OpModel(op={self.op_name})"


class OpModelGenerator:
    """
    Generator that creates OpModel instances with appropriate inputs for testing.

    This class takes a torch operation and its specification, then uses
    ArgumentTupleGenerator to create OpModel instances along with valid
    inputs for the forward function. It automatically separates tensor inputs
    from non-tensor parameters for ExecuTorch compatibility using FACTO's
    ArgType system.
    """

    def __init__(self, op: Any, spec: Spec, config: Optional[TensorConfig] = None):
        """
        Initialize the OpModelGenerator.

        Args:
            op: The torch aten operation to wrap in models
            spec: The specification for the operation's arguments
            config: Optional tensor configuration for input generation
        """
        self.op = op
        self.spec = spec
        self.config = config
        self.arg_generator = ArgumentTupleGenerator(spec, config)

    def gen(
        self,
        *,
        valid: bool = True,
        verbose: bool = False,
        max_count: Optional[int] = None,
    ) -> Generator[Tuple[OpModel, List[Any], Dict[str, Any]], None, None]:
        """
        Generate OpModel instances with corresponding inputs.

        Args:
            valid: Whether to generate valid inputs (default: True)
            verbose: Whether to print each generated meta tuple (default: False)
            max_count: Maximum number of models to generate (default: None for unlimited)

        Yields:
            Tuple containing:
                - OpModel instance wrapping the operation
                - List of positional arguments for forward()
                - Dict of keyword arguments for forward()
        """
        count = 0
        for args, kwargs, _ in self.arg_generator.gen(
            valid=valid, out=False, verbose=verbose
        ):
            if max_count is not None and count >= max_count:
                break

            # Separate tensor inputs from non-tensor parameters
            forward_args, forward_kwargs, model_args, model_kwargs = (
                separate_forward_and_model_inputs(self.spec, args, kwargs)
            )

            # Create model instance
            model = OpModel(
                self.op, self.spec, self.spec.op, *model_args, **model_kwargs
            )

            yield model, forward_args, forward_kwargs
            count += 1

    def test_model_with_inputs(
        self, model: OpModel, args: List[Any], kwargs: Dict[str, Any]
    ) -> Tuple[bool, Optional[Any], Optional[Exception]]:
        """
        Test a model with given inputs and return the result.

        Args:
            model: The OpModel to test
            args: Positional arguments for the model
            kwargs: Keyword arguments for the model

        Returns:
            Tuple containing:
                - Boolean indicating success/failure
                - The output if successful, None if failed
                - The exception if failed, None if successful
        """
        try:
            output = model(*args, **kwargs)
            return True, output, None
        except Exception as e:
            return False, None, e

    def __repr__(self) -> str:
        return f"OpModelGenerator(op={self.spec.op})"