
Commit 21ceb8e

Track API usage (#2706)

1 parent 715ea9f · commit 21ceb8e

11 files changed (+98, -8 lines)
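For context, the added calls go through `torch._C._log_api_usage_once`, PyTorch's internal hook for recording which APIs are used. A given key is recorded at most once per process, and in open-source builds the default handler is a no-op, so the instrumentation should not change runtime behavior. A minimal sketch of the call's semantics (the key below is a made-up placeholder):

import torch

# Records the key the first time it is seen in this process; the second call
# is deduplicated. With no usage handler registered (the default in OSS
# builds), both calls are effectively no-ops.
torch._C._log_api_usage_once("torchao.example.hypothetical_key")
torch._C._log_api_usage_once("torchao.example.hypothetical_key")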

torchao/float8/float8_linear_utils.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from functools import partial
 from typing import Callable, List, Optional, Union
 
+import torch
 import torch.nn as nn
 
 from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
@@ -101,6 +102,7 @@ def convert_to_float8_training(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    torch._C._log_api_usage_once("torchao.float8.convert_to_float8_training")
     if config is None:
         config = Float8LinearConfig()
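A usage sketch of the newly instrumented entry point (toy model only; real float8 training additionally assumes a float8-capable GPU and layer dims divisible by 16):

import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Swapping nn.Linear modules for Float8Linear now also records the
# "torchao.float8.convert_to_float8_training" usage key once per process.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))
model = convert_to_float8_training(model)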

torchao/float8/fsdp_utils.py

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 
     from torchao.float8.float8_linear import Float8Linear
 
+    torch._C._log_api_usage_once(
+        "torchao.float8.precompute_float8_dynamic_scale_for_fsdp"
+    )
+
     float8_linears: List[Float8Linear] = [
         m
         for m in module.modules()
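A brief sketch of the corresponding call site. Per the torchao float8 docs this is intended to run after `optimizer.step()` in an FSDP2 float8 training loop; on an unsharded toy model, as here, it has nothing to precompute but still records its usage key:

import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

model = convert_to_float8_training(nn.Sequential(nn.Linear(64, 64)))
# In real use: call once per training step, after optimizer.step(), so the
# dynamic float8 scales for all FSDP-sharded weights are computed in one batch.
precompute_float8_dynamic_scale_for_fsdp(model)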

torchao/optim/adam.py

Lines changed: 6 additions & 0 deletions
@@ -233,6 +233,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.Adam8bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -263,6 +264,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.Adam4bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -293,6 +295,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamFp8")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -323,6 +326,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamW8bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -353,6 +357,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamW4bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -383,6 +388,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamWFp8")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
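A usage sketch for one of the instrumented optimizers (toy model; the same pattern applies to the other five classes touched here):

import torch.nn as nn
from torchao.optim import AdamW8bit

# Constructing the optimizer now records "torchao.optim.AdamW8bit" once per
# process; optimizer behavior is otherwise unchanged.
model = nn.Linear(4096, 4096)
optimizer = AdamW8bit(model.parameters(), lr=1e-3)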

torchao/quantization/pt2e/convert.py

Lines changed: 0 additions & 3 deletions
@@ -1266,9 +1266,6 @@ def _convert_to_reference_decomposed_fx(
         reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
 
     """
-    torch._C._log_api_usage_once(
-        "quantization_api.quantize_fx._convert_to_reference_decomposed_fx"
-    )
     return _convert_fx(
         graph_module,
         is_reference=True,

torchao/quantization/pt2e/quantize_pt2e.py

Lines changed: 3 additions & 3 deletions
@@ -106,7 +106,7 @@ def calibrate(model, data_loader):
 
         return torch_prepare_pt2e(model, quantizer)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
     # TODO: check qconfig_mapping to make sure conv and bn are both configured
@@ -192,7 +192,7 @@ def train_loop(model, train_data):
 
         return torch_prepare_qat_pt2e(model, quantizer)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_qat_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
     model = quantizer.transform_for_annotation(model)
@@ -304,7 +304,7 @@ def convert_pt2e(
 
         return torch_convert_pt2e(model, use_reference_representation, fold_quantize)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.convert_pt2e")
     if not isinstance(use_reference_representation, bool):
         raise ValueError(
             "Unexpected argument type for `use_reference_representation`, "

torchao/quantization/qat/api.py

Lines changed: 4 additions & 0 deletions
@@ -146,6 +146,7 @@ def __init__(
         self.__post_init__()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.qat.QATConfig")
         self.step = self.step.lower()
         all_step_values = [s.value for s in QATStep]
         if self.step not in all_step_values:
@@ -377,6 +378,7 @@ class ComposableQATQuantizer(TwoStepQuantizer):
     """
 
     def __init__(self, quantizers: List[TwoStepQuantizer]):
+        torch._C._log_api_usage_once("torchao.quantization.qat.ComposableQATQuantizer")
         self.quantizers = quantizers
 
     def prepare(
@@ -403,6 +405,8 @@ def initialize_fake_quantizers(
     :class:`~torchao.quantization.qat.fake_quantizer.IntxFakeQuantizerBase`
     in the model based on the provided example inputs.
     """
+    torch._C._log_api_usage_once("torchao.quantization.qat.initialize_fake_quantizers")
+
     # avoid circular dependencies
     from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer
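A sketch of the QAT flow these classes belong to, assuming the config-based entry point (`QATConfig` taking a base config and a `step` argument, as suggested by the `QATStep` check above); treat the exact signatures as assumptions rather than guarantees:

import torch.nn as nn
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

model = nn.Sequential(nn.Linear(256, 256))
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# The prepare step inserts fake quantization (and records the QATConfig usage key).
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune the model here ...
# The convert step swaps in the real quantized representation.
quantize_(model, QATConfig(base_config, step="convert"))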

torchao/quantization/qat/embedding.py

Lines changed: 4 additions & 0 deletions
@@ -65,6 +65,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedEmbedding")
         if weight_config is not None:
             self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
         else:
@@ -148,6 +149,9 @@ def __init__(
         zero_point_precision: torch.dtype = torch.int32,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer"
+        )
         self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision

torchao/quantization/qat/fake_quantizer.py

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ class IntxFakeQuantizer(FakeQuantizerBase):
 
     def __init__(self, config: IntxFakeQuantizeConfig):
         super().__init__()
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
        self.config = config
         self.enabled = True
         self.scale: Optional[torch.Tensor] = None

torchao/quantization/qat/linear.py

Lines changed: 10 additions & 0 deletions
@@ -81,6 +81,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedLinear")
         # initialize activation fake quantizer
         if activation_config is not None:
             self.activation_fake_quantizer = FakeQuantizerBase.from_config(
@@ -210,6 +211,9 @@ def __init__(
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer"
+        )
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
@@ -413,6 +417,9 @@ def __init__(
         scales_precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightOnlyQATQuantizer"
+        )
         assert inner_k_tiles in [2, 4, 8]
         assert groupsize in [32, 64, 128, 256]
         self.inner_k_tiles = inner_k_tiles
@@ -594,6 +601,9 @@ def __init__(
         group_size: Optional[int] = 64,
         scale_precision: torch.dtype = torch.bfloat16,
     ):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Float8ActInt4WeightQATQuantizer"
+        )
         if group_size is not None:
             weight_granularity = "per_group"
         else:
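The quantizer classes instrumented here follow torchao's two-step QAT flow; a sketch with the int8-dynamic-activation / int4-weight quantizer (toy model, dims chosen so the group size divides the layer widths):

import torch.nn as nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

# Instantiation records "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer".
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = qat_quantizer.prepare(model)   # insert fake-quantize ops for training
# ... fine-tune the prepared model ...
model = qat_quantizer.convert(model)   # swap in the quantized implementation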

torchao/quantization/quant_api.py

Lines changed: 58 additions & 1 deletion
@@ -127,6 +127,7 @@
 
 logger = logging.getLogger(__name__)
 
+# TODO: revisit this list?
 __all__ = [
     "swap_conv2d_1x1_to_linear",
     "Quantizer",
@@ -510,6 +511,8 @@ def quantize_(
         quantize_(m, int4_weight_only(group_size=32))
 
     """
+    torch._C._log_api_usage_once("torchao.quantization.quantize_")
+
     filter_fn = _is_linear if filter_fn is None else filter_fn
 
     if isinstance(config, ModuleFqnToConfig):
@@ -619,6 +622,11 @@ class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
     act_mapping_type: MappingType = MappingType.ASYMMETRIC
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationInt4WeightConfig"
+        )
+
 
 # for BC
 int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
@@ -729,6 +737,9 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
     layout: Layout = QDQLayout()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationIntxWeightConfig"
+        )
         assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], (
             f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
         )
@@ -876,6 +887,11 @@ class Int4DynamicActivationInt4WeightConfig(AOBaseConfig):
     act_mapping_type: MappingType = MappingType.SYMMETRIC
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int4DynamicActivationInt4WeightConfig"
+        )
+
 
 # for bc
 int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
@@ -932,6 +948,11 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     mode: Optional[str] = "weight_only"
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.GemliteUIntXWeightOnlyConfig"
+        )
+
 
 # for BC
 gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
@@ -1005,6 +1026,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     packing_format: PackingFormat = PackingFormat.PLAIN
     VERSION: int = 1
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
+
 
 # for BC
 # TODO maybe change other callsites
@@ -1178,6 +1202,9 @@ class Int8WeightOnlyConfig(AOBaseConfig):
     group_size: Optional[int] = None
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
+
 
 # for BC
 int8_weight_only = Int8WeightOnlyConfig
@@ -1334,6 +1361,11 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     weight_only_decode: bool = False
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationInt8WeightConfig"
+        )
+
 
 # for BC
 int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
@@ -1438,6 +1470,9 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     set_inductor_config: bool = True
     version: int = 2
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
+
 
 # for BC
 float8_weight_only = Float8WeightOnlyConfig
@@ -1586,9 +1621,11 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     version: int = 2
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
+        )
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
-
         activation_granularity, weight_granularity = _normalize_granularity(
             self.granularity
         )
@@ -1705,6 +1742,11 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8SemiSparseWeightConfig"
+        )
+
 
 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
 def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
@@ -1756,6 +1798,11 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     mm_config: Optional[Float8MMConfig] = Float8MMConfig(use_fast_accum=True)
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8StaticActivationFloat8WeightConfig"
+        )
+
 
 # for bc
 float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
@@ -1836,6 +1883,9 @@ class UIntXWeightOnlyConfig(AOBaseConfig):
     use_hqq: bool = False
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
+
 
 # for BC
 uintx_weight_only = UIntXWeightOnlyConfig
@@ -1934,6 +1984,7 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     layout: Layout = QDQLayout()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig")
         assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], (
             f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
         )
@@ -2007,6 +2058,9 @@ class FPXWeightOnlyConfig(AOBaseConfig):
     mbits: int
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
+
 
 # for BC
 fpx_weight_only = FPXWeightOnlyConfig
@@ -2138,6 +2192,9 @@ class ModuleFqnToConfig(AOBaseConfig):
         default_factory=dict
     )
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig")
+
 
 def _module_fqn_to_config_handler(
     module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig
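Finally, a usage sketch for the main entry point: with these changes, both `quantize_` and each config's `__post_init__` record a usage key, so a typical call now logs two keys (toy model for illustration):

import torch.nn as nn
from torchao.quantization import Int8WeightOnlyConfig, quantize_

# Records "torchao.quantization.Int8WeightOnlyConfig" (config construction)
# and "torchao.quantization.quantize_" (the transform itself), once each.
model = nn.Sequential(nn.Linear(128, 128))
quantize_(model, Int8WeightOnlyConfig())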
