Skip to content

Enable half/bfloat16 dtypes #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
10 changes: 5 additions & 5 deletions facto/inputgen/argument/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import facto.utils.dtypes as dt
import torch
from facto.inputgen.argument.engine import MetaArg
from facto.inputgen.utils.config import Condition, TensorConfig
from facto.inputgen.utils.random_manager import seeded_random_manager
from facto.inputgen.variable.gen import VariableGenerator
from facto.inputgen.variable.space import VariableSpace
from torch.testing._internal.common_dtype import floating_types, integral_types


FLOAT_RESOLUTION = 8
Expand Down Expand Up @@ -218,10 +218,10 @@ def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor:
low=0, high=2, size=size, dtype=dtype, generator=torch_rng
)

if dtype in integral_types():
if dtype in dt._int:
low = math.ceil(low)
high = math.floor(high) + 1
elif dtype in floating_types():
elif dtype in dt._floating:
low = math.ceil(FLOAT_RESOLUTION * low)
high = math.floor(FLOAT_RESOLUTION * high) + 1
else:
Expand Down Expand Up @@ -263,9 +263,9 @@ def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor:
)
t = torch.where(t == 0, pos, t)

if dtype in integral_types():
if dtype in dt._int:
return t
if dtype in floating_types():
if dtype in dt._floating:
return t / FLOAT_RESOLUTION
raise ValueError(f"Unsupported Dtype: {dtype}")

Expand Down
8 changes: 6 additions & 2 deletions facto/inputgen/attribute/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from facto.inputgen.attribute.solve import AttributeSolver
from facto.inputgen.specs.model import Constraint
from facto.inputgen.variable.gen import VariableGenerator
from facto.inputgen.variable.type import ScalarDtype, sort_values_of_type
from facto.inputgen.variable.type import (
ScalarDtype,
sort_values_of_type,
SUPPORTED_TENSOR_DTYPES,
)


class AttributeEngine(AttributeSolver):
Expand All @@ -33,7 +37,7 @@ def gen(self, focus: Attribute, *args):
num = 2
elif self.attribute == focus:
if self.attribute == Attribute.DTYPE:
num = 8
num = len(SUPPORTED_TENSOR_DTYPES)
else:
num = 6
else:
Expand Down
4 changes: 2 additions & 2 deletions facto/inputgen/variable/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def __str__(self):
torch.int64,
torch.float32,
torch.float64,
torch.float16,
torch.bfloat16,
# The following types are not supported yet, but we should support them soon:
# torch.float16,
# torch.complex32,
# torch.complex64,
# torch.complex128,
# torch.bfloat16,
]


Expand Down
2 changes: 1 addition & 1 deletion facto/specdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import math

import facto.specdb.dtypes as dt
import facto.specdb.function as fn
import facto.utils.dtypes as dt
import torch
from facto.inputgen.argument.type import ArgType
from facto.inputgen.specs.model import (
Expand Down
5 changes: 2 additions & 3 deletions facto/specdb/dtypes.py → facto/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

_int = [torch.uint8, torch.int8, torch.short, torch.int, torch.long]
_int_and_bool = [torch.bool] + _int
_floating = [torch.float, torch.double]
_floating = [torch.float16, torch.bfloat16, torch.float, torch.double]
_real = _int + _floating
_real_and_bool = [torch.bool] + _int + _floating
_floating_and_half = [torch.half] + _floating
_complex = [torch.chalf, torch.cfloat, torch.cdouble]
_quant = [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
_all = [torch.bool] + _int + _floating_and_half + _complex + _quant
_all = _real_and_bool + _complex + _quant


def can_cast_from(t):
Expand Down
4 changes: 4 additions & 0 deletions test/specdb/test_specdb_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def test_all_ops_cpu(self):
"split_with_sizes_copy.default",
]

# "cdist" not implemented for 'Half' on CPU
# "pdist" not implemented for 'Half' on CPU
skip_ops += ["_cdist_forward.default", "_pdist_forward.default"]

self._run_all_ops(skip_ops=skip_ops)


Expand Down
3 changes: 3 additions & 0 deletions test/specdb/test_specdb_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def test_all_ops_mps(self):
"var.dim",
]

# ConvTranspose 3D with BF16 or FP16 types is not supported on MPS
skip_ops += ["convolution.default"]

config = TensorConfig(device="mps", disallow_dtypes=[torch.float64])
self._run_all_ops(config=config, skip_ops=skip_ops)

Expand Down