diff --git a/facto/inputgen/argument/gen.py b/facto/inputgen/argument/gen.py index caffb31..02891a5 100644 --- a/facto/inputgen/argument/gen.py +++ b/facto/inputgen/argument/gen.py @@ -8,13 +8,13 @@ from dataclasses import dataclass from typing import Optional, Tuple +import facto.utils.dtypes as dt import torch from facto.inputgen.argument.engine import MetaArg from facto.inputgen.utils.config import Condition, TensorConfig from facto.inputgen.utils.random_manager import seeded_random_manager from facto.inputgen.variable.gen import VariableGenerator from facto.inputgen.variable.space import VariableSpace -from torch.testing._internal.common_dtype import floating_types, integral_types FLOAT_RESOLUTION = 8 @@ -218,10 +218,10 @@ def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor: low=0, high=2, size=size, dtype=dtype, generator=torch_rng ) - if dtype in integral_types(): + if dtype in dt._int: low = math.ceil(low) high = math.floor(high) + 1 - elif dtype in floating_types(): + elif dtype in dt._floating: low = math.ceil(FLOAT_RESOLUTION * low) high = math.floor(FLOAT_RESOLUTION * high) + 1 else: @@ -263,9 +263,9 @@ def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor: ) t = torch.where(t == 0, pos, t) - if dtype in integral_types(): + if dtype in dt._int: return t - if dtype in floating_types(): + if dtype in dt._floating: return t / FLOAT_RESOLUTION raise ValueError(f"Unsupported Dtype: {dtype}") diff --git a/facto/inputgen/attribute/engine.py b/facto/inputgen/attribute/engine.py index df8fc2a..3acf72f 100644 --- a/facto/inputgen/attribute/engine.py +++ b/facto/inputgen/attribute/engine.py @@ -12,7 +12,11 @@ from facto.inputgen.attribute.solve import AttributeSolver from facto.inputgen.specs.model import Constraint from facto.inputgen.variable.gen import VariableGenerator -from facto.inputgen.variable.type import ScalarDtype, sort_values_of_type +from facto.inputgen.variable.type import ( + ScalarDtype, + sort_values_of_type, + SUPPORTED_TENSOR_DTYPES, +) class AttributeEngine(AttributeSolver): @@ -33,7 +37,7 @@ def gen(self, focus: Attribute, *args): num = 2 elif self.attribute == focus: if self.attribute == Attribute.DTYPE: - num = 8 + num = len(SUPPORTED_TENSOR_DTYPES) else: num = 6 else: diff --git a/facto/inputgen/variable/type.py b/facto/inputgen/variable/type.py index 0a5f124..eadd88b 100644 --- a/facto/inputgen/variable/type.py +++ b/facto/inputgen/variable/type.py @@ -35,12 +35,12 @@ def __str__(self): torch.int64, torch.float32, torch.float64, + torch.float16, + torch.bfloat16, # The following types are not supported yet, but we should support them soon: - # torch.float16, # torch.complex32, # torch.complex64, # torch.complex128, - # torch.bfloat16, ] diff --git a/facto/specdb/db.py b/facto/specdb/db.py index 44ac52b..33b3436 100644 --- a/facto/specdb/db.py +++ b/facto/specdb/db.py @@ -6,8 +6,8 @@ import math -import facto.specdb.dtypes as dt import facto.specdb.function as fn +import facto.utils.dtypes as dt import torch from facto.inputgen.argument.type import ArgType from facto.inputgen.specs.model import ( diff --git a/facto/specdb/dtypes.py b/facto/utils/dtypes.py similarity index 81% rename from facto/specdb/dtypes.py rename to facto/utils/dtypes.py index dc20698..e383f4f 100644 --- a/facto/specdb/dtypes.py +++ b/facto/utils/dtypes.py @@ -8,13 +8,12 @@ _int = [torch.uint8, torch.int8, torch.short, torch.int, torch.long] _int_and_bool = [torch.bool] + _int -_floating = [torch.float, torch.double] +_floating = [torch.float16, torch.bfloat16, torch.float, torch.double] _real = _int + _floating _real_and_bool = [torch.bool] + _int + _floating -_floating_and_half = [torch.half] + _floating _complex = [torch.chalf, torch.cfloat, torch.cdouble] _quant = [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4] -_all = [torch.bool] + _int + _floating_and_half + _complex + _quant +_all = _real_and_bool + _complex + _quant def can_cast_from(t): diff --git a/test/specdb/test_specdb_cpu.py b/test/specdb/test_specdb_cpu.py index 1d93654..b4edec4 100644 --- a/test/specdb/test_specdb_cpu.py +++ b/test/specdb/test_specdb_cpu.py @@ -22,6 +22,10 @@ def test_all_ops_cpu(self): "split_with_sizes_copy.default", ] + # "cdist" not implemented for 'Half' on CPU + # "pdist" not implemented for 'Half' on CPU + skip_ops += ["_cdist_forward.default", "_pdist_forward.default"] + self._run_all_ops(skip_ops=skip_ops) diff --git a/test/specdb/test_specdb_mps.py b/test/specdb/test_specdb_mps.py index 60de098..8e4642f 100644 --- a/test/specdb/test_specdb_mps.py +++ b/test/specdb/test_specdb_mps.py @@ -51,6 +51,9 @@ def test_all_ops_mps(self): "var.dim", ] + # ConvTranspose 3D with BF16 or FP16 types is not supported on MPS + skip_ops += ["convolution.default"] + config = TensorConfig(device="mps", disallow_dtypes=[torch.float64]) self._run_all_ops(config=config, skip_ops=skip_ops)