File tree Expand file tree Collapse file tree 3 files changed +23
-3
lines changed
Expand file tree Collapse file tree 3 files changed +23
-3
lines changed Original file line number Diff line number Diff line change 1616)
1717from torchao .float8 .fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
1818from torchao .float8 .inference import Float8MMConfig
19+ from torchao .float8 .types import FP8Granularity
1920from torchao .utils import TORCH_VERSION_AT_LEAST_2_5
2021
2122if TORCH_VERSION_AT_LEAST_2_5 :
4142 # top level UX
4243 "convert_to_float8_training" ,
4344 "precompute_float8_dynamic_scale_for_fsdp" ,
45+ # types
46+ "FP8Granularity" ,
4447 # note: Float8Tensor and Float8Linear are not public APIs
4548]
Original file line number Diff line number Diff line change 1212import torch
1313
1414from torchao .float8 .float8_utils import is_row_major , pad_tensor_for_matmul
15+ from torchao .float8 .types import FP8Granularity
1516from torchao .quantization .granularity import (
1617 PerRow ,
1718 PerTensor ,
@@ -116,9 +117,6 @@ def _is_rowwise_scaled(x) -> bool:
116117 return x .block_size == (1 ,) * (x .dim () - 1 ) + (x .shape [- 1 ],)
117118
118119
119- FP8Granularity = Union [PerTensor , PerRow ]
120-
121-
122120def _normalize_granularity (
123121 granularity : Optional [
124122 Union [
Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD 3-Clause license found in the
5+ # LICENSE file in the root directory of this source tree.
6+ """
7+ Common types for float8 quantization
8+ """
9+
10+ from __future__ import annotations
11+
12+ from typing import TYPE_CHECKING , Union
13+
14+ if TYPE_CHECKING :
15+ from torchao .quantization .granularity import PerRow , PerTensor
16+
17+
18+ # Define FP8Granularity type alias to break circular import dependencies
19+ FP8Granularity = Union ["PerTensor" , "PerRow" ]
You can’t perform that action at this time.
0 commit comments