
Commit 6c65c63

manuelcandales authored and facebook-github-bot committed
Enable half/bfloat16 dtypes (#37)
Summary:
Pull Request resolved: #37
imported-using-ghimport

Test Plan: Imported from OSS
Rollback Plan:
Reviewed By: digantdesai
Differential Revision: D80468301
Pulled By: manuelcandales
fbshipit-source-id: 837aa6e0f0a93d8228accdff7358977ea8ce8b44
1 parent 70c7036 commit 6c65c63

7 files changed: +23 -13 lines changed

facto/inputgen/argument/gen.py

Lines changed: 5 additions & 5 deletions
@@ -8,13 +8,13 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
+import facto.utils.dtypes as dt
 import torch
 from facto.inputgen.argument.engine import MetaArg
 from facto.inputgen.utils.config import Condition, TensorConfig
 from facto.inputgen.utils.random_manager import seeded_random_manager
 from facto.inputgen.variable.gen import VariableGenerator
 from facto.inputgen.variable.space import VariableSpace
-from torch.testing._internal.common_dtype import floating_types, integral_types
 
 
 FLOAT_RESOLUTION = 8
@@ -218,10 +218,10 @@ def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor:
                 low=0, high=2, size=size, dtype=dtype, generator=torch_rng
             )
 
-        if dtype in integral_types():
+        if dtype in dt._int:
             low = math.ceil(low)
             high = math.floor(high) + 1
-        elif dtype in floating_types():
+        elif dtype in dt._floating:
             low = math.ceil(FLOAT_RESOLUTION * low)
             high = math.floor(FLOAT_RESOLUTION * high) + 1
         else:
@@ -263,9 +263,9 @@ def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor:
             )
         t = torch.where(t == 0, pos, t)
 
-        if dtype in integral_types():
+        if dtype in dt._int:
             return t
-        if dtype in floating_types():
+        if dtype in dt._floating:
             return t / FLOAT_RESOLUTION
         raise ValueError(f"Unsupported Dtype: {dtype}")
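For context, get_random_tensor draws floating-point values on a fixed 1/8 grid: it samples integers in [ceil(FLOAT_RESOLUTION * low), floor(FLOAT_RESOLUTION * high)] and divides by FLOAT_RESOLUTION, so generated values sit on an exact binary grid even in float16/bfloat16. A minimal standalone sketch of that scheme (hypothetical helper, reusing the dt lists from this diff):

    import math
    import torch
    import facto.utils.dtypes as dt

    FLOAT_RESOLUTION = 8  # floats land on a 1/8 grid

    def sketch_random_tensor(size, dtype, low, high):
        if dtype in dt._int:
            # integers: sample directly in [ceil(low), floor(high)]
            return torch.randint(math.ceil(low), math.floor(high) + 1, size, dtype=dtype)
        if dtype in dt._floating:
            # floats (now including half/bfloat16): sample scaled integers,
            # then divide, so e.g. 21/8 = 2.625 is exact in half precision
            lo = math.ceil(FLOAT_RESOLUTION * low)
            hi = math.floor(FLOAT_RESOLUTION * high) + 1
            return (torch.randint(lo, hi, size) / FLOAT_RESOLUTION).to(dtype)
        raise ValueError(f"Unsupported Dtype: {dtype}")

    # e.g. sketch_random_tensor((2, 3), torch.bfloat16, low=-1.0, high=1.0)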

facto/inputgen/attribute/engine.py

Lines changed: 6 additions & 2 deletions
@@ -12,7 +12,11 @@
 from facto.inputgen.attribute.solve import AttributeSolver
 from facto.inputgen.specs.model import Constraint
 from facto.inputgen.variable.gen import VariableGenerator
-from facto.inputgen.variable.type import ScalarDtype, sort_values_of_type
+from facto.inputgen.variable.type import (
+    ScalarDtype,
+    sort_values_of_type,
+    SUPPORTED_TENSOR_DTYPES,
+)
 
 
 class AttributeEngine(AttributeSolver):
@@ -33,7 +37,7 @@ def gen(self, focus: Attribute, *args):
                 num = 2
         elif self.attribute == focus:
             if self.attribute == Attribute.DTYPE:
-                num = 8
+                num = len(SUPPORTED_TENSOR_DTYPES)
             else:
                 num = 6
         else:
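With num tied to SUPPORTED_TENSOR_DTYPES, the dtype enumeration grows automatically as dtypes are enabled. A quick check (hypothetical snippet, assuming the updated SUPPORTED_TENSOR_DTYPES from facto/inputgen/variable/type.py below):

    import torch
    from facto.inputgen.variable.type import SUPPORTED_TENSOR_DTYPES

    # one candidate per supported dtype; float16/bfloat16 are picked up
    # without touching the previously hardcoded count of 8
    assert torch.float16 in SUPPORTED_TENSOR_DTYPES
    assert torch.bfloat16 in SUPPORTED_TENSOR_DTYPES
    print(f"{len(SUPPORTED_TENSOR_DTYPES)} dtype candidates for Attribute.DTYPE")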

facto/inputgen/variable/type.py

Lines changed: 2 additions & 2 deletions
@@ -35,12 +35,12 @@ def __str__(self):
     torch.int64,
     torch.float32,
     torch.float64,
+    torch.float16,
+    torch.bfloat16,
     # The following types are not supported yet, but we should support them soon:
-    # torch.float16,
     # torch.complex32,
     # torch.complex64,
     # torch.complex128,
-    # torch.bfloat16,
 ]

facto/specdb/db.py

Lines changed: 1 addition & 1 deletion
@@ -6,8 +6,8 @@
 
 import math
 
-import facto.specdb.dtypes as dt
 import facto.specdb.function as fn
+import facto.utils.dtypes as dt
 import torch
 from facto.inputgen.argument.type import ArgType
 from facto.inputgen.specs.model import (

facto/specdb/dtypes.py renamed to facto/utils/dtypes.py

Lines changed: 2 additions & 3 deletions
@@ -8,13 +8,12 @@
 
 _int = [torch.uint8, torch.int8, torch.short, torch.int, torch.long]
 _int_and_bool = [torch.bool] + _int
-_floating = [torch.float, torch.double]
+_floating = [torch.float16, torch.bfloat16, torch.float, torch.double]
 _real = _int + _floating
 _real_and_bool = [torch.bool] + _int + _floating
-_floating_and_half = [torch.half] + _floating
 _complex = [torch.chalf, torch.cfloat, torch.cdouble]
 _quant = [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
-_all = [torch.bool] + _int + _floating_and_half + _complex + _quant
+_all = _real_and_bool + _complex + _quant
 
 
 def can_cast_from(t):
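For reference, a small sanity check of the regrouped lists (hypothetical snippet; the identities follow directly from the definitions above):

    import torch
    import facto.utils.dtypes as dt

    # half and bfloat16 are now ordinary members of _floating
    assert torch.float16 in dt._floating and torch.bfloat16 in dt._floating
    # _all is rebuilt from _real_and_bool (bool + ints + floats),
    # plus the complex and quantized groups
    assert dt._all == dt._real_and_bool + dt._complex + dt._quant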

test/specdb/test_specdb_cpu.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@ def test_all_ops_cpu(self):
             "split_with_sizes_copy.default",
         ]
 
+        # "cdist" not implemented for 'Half' on CPU
+        # "pdist" not implemented for 'Half' on CPU
+        skip_ops += ["_cdist_forward.default", "_pdist_forward.default"]
+
         self._run_all_ops(skip_ops=skip_ops)
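These skips fail inside the PyTorch CPU kernels rather than in FACTO itself; a minimal reproduction (hypothetical snippet, based on the error text quoted in the comments above):

    import torch

    x = torch.randn(4, 3).to(torch.float16)  # CPU tensor
    try:
        torch.cdist(x, x)
    except RuntimeError as e:
        print(e)  # "cdist" not implemented for 'Half'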

test/specdb/test_specdb_mps.py

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,9 @@ def test_all_ops_mps(self):
             "var.dim",
         ]
 
+        # ConvTranspose 3D with BF16 or FP16 types is not supported on MPS
+        skip_ops += ["convolution.default"]
+
         config = TensorConfig(device="mps", disallow_dtypes=[torch.float64])
         self._run_all_ops(config=config, skip_ops=skip_ops)
