
Commit e700c81

Track API usage

1 parent 948ade1 commit e700c81

11 files changed: +98 -8 lines

torchao/float8/float8_linear_utils.py
Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 from functools import partial
 from typing import Callable, List, Optional, Union
 
+import torch
 import torch.nn as nn
 
 from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
@@ -101,6 +102,7 @@ def convert_to_float8_training(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    torch._C._log_api_usage_once("torchao.float8.convert_to_float8_training")
    if config is None:
        config = Float8LinearConfig()
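
Note for reviewers: torch._C._log_api_usage_once records a key at most once per process, and during local testing the events can be surfaced on stderr with PyTorch's PYTORCH_API_USAGE_STDERR=1 debug switch. A minimal sketch for spot-checking the new float8 entry point (the logging call sits at the top of convert_to_float8_training, before any conversion work, so the event should be recorded even on hosts without float8-capable hardware; the script name is just an example):

    # Run as: PYTORCH_API_USAGE_STDERR=1 python check_float8_logging.py
    import torch.nn as nn

    from torchao.float8 import convert_to_float8_training

    model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

    # Should emit a stderr line containing
    # "torchao.float8.convert_to_float8_training" (once per process).
    convert_to_float8_training(model)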

torchao/float8/fsdp_utils.py
Lines changed: 4 additions & 0 deletions

@@ -39,6 +39,10 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 
     from torchao.float8.float8_linear import Float8Linear
 
+    torch._C._log_api_usage_once(
+        "torchao.float8.precompute_float8_dynamic_scale_for_fsdp"
+    )
+
     float8_linears: List[Float8Linear] = [
         m
         for m in module.modules()

torchao/optim/adam.py
Lines changed: 6 additions & 0 deletions

@@ -233,6 +233,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.Adam8bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -263,6 +264,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.Adam4bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -293,6 +295,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamFp8")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -323,6 +326,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamW8bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -353,6 +357,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamW4bit")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
@@ -383,6 +388,7 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        torch._C._log_api_usage_once("torchao.optim.AdamWFp8")
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):

torchao/quantization/pt2e/convert.py
Lines changed: 0 additions & 3 deletions

@@ -1271,9 +1271,6 @@ def _convert_to_reference_decomposed_fx(
         reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
 
     """
-    torch._C._log_api_usage_once(
-        "quantization_api.quantize_fx._convert_to_reference_decomposed_fx"
-    )
     return _convert_fx(
         graph_module,
         is_reference=True,

torchao/quantization/pt2e/quantize_pt2e.py
Lines changed: 3 additions & 3 deletions

@@ -106,7 +106,7 @@ def calibrate(model, data_loader):
 
         return torch_prepare_pt2e(model, quantizer)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
     # TODO: check qconfig_mapping to make sure conv and bn are both configured
@@ -192,7 +192,7 @@ def train_loop(model, train_data):
 
         return torch_prepare_qat_pt2e(model, quantizer)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_qat_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
     model = quantizer.transform_for_annotation(model)
@@ -309,7 +309,7 @@ def convert_pt2e(
 
         return torch_convert_pt2e(model, use_reference_representation, fold_quantize)
 
-    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
+    torch._C._log_api_usage_once("torchao.quantization.pt2e.convert_pt2e")
     if not isinstance(use_reference_representation, bool):
         raise ValueError(
             "Unexpected argument type for `use_reference_representation`, "

torchao/quantization/qat/api.py
Lines changed: 4 additions & 0 deletions

@@ -144,6 +144,7 @@ def __init__(
         self.__post_init__()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.qat.QATConfig")
         self.step = self.step.lower()
         all_step_values = [s.value for s in QATStep]
         if self.step not in all_step_values:
@@ -359,6 +360,7 @@ class ComposableQATQuantizer(TwoStepQuantizer):
     """
 
     def __init__(self, quantizers: List[TwoStepQuantizer]):
+        torch._C._log_api_usage_once("torchao.quantization.qat.ComposableQATQuantizer")
         self.quantizers = quantizers
 
     def prepare(
@@ -385,6 +387,8 @@ def initialize_fake_quantizers(
     :class:`~torchao.quantization.qat.fake_quantizer.IntxFakeQuantizerBase`
     in the model based on the provided example inputs.
     """
+    torch._C._log_api_usage_once("torchao.quantization.qat.initialize_fake_quantizers")
+
    # avoid circular dependencies
    from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer
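
Note for reviewers: because the logging lives in the constructors, composing QAT components records one key per instrumented class. A minimal sketch (assumes both classes are re-exported from torchao.quantization.qat as in current releases; Int8DynActInt4WeightQATQuantizer is instrumented further down in this diff, in qat/linear.py):

    from torchao.quantization.qat import (
        ComposableQATQuantizer,
        Int8DynActInt4WeightQATQuantizer,
    )

    # Logs "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer" and
    # "torchao.quantization.qat.ComposableQATQuantizer", one event each per process.
    quantizer = ComposableQATQuantizer([Int8DynActInt4WeightQATQuantizer()])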

torchao/quantization/qat/embedding.py
Lines changed: 4 additions & 0 deletions

@@ -65,6 +65,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedEmbedding")
         if weight_config is not None:
             self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
         else:
@@ -148,6 +149,9 @@ def __init__(
         zero_point_precision: torch.dtype = torch.int32,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer"
+        )
         self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision

torchao/quantization/qat/fake_quantizer.py
Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@ class IntxFakeQuantizer(FakeQuantizerBase):
 
     def __init__(self, config: IntxFakeQuantizeConfig):
         super().__init__()
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
         self.config = config
         self.enabled = True
         self.scale: Optional[torch.Tensor] = None

torchao/quantization/qat/linear.py
Lines changed: 10 additions & 0 deletions

@@ -82,6 +82,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedLinear")
         # initialize activation fake quantizer
         if activation_config is not None:
             self.activation_fake_quantizer = FakeQuantizerBase.from_config(
@@ -211,6 +212,9 @@ def __init__(
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer"
+        )
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
@@ -414,6 +418,9 @@ def __init__(
         scales_precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightOnlyQATQuantizer"
+        )
         assert inner_k_tiles in [2, 4, 8]
         assert groupsize in [32, 64, 128, 256]
         self.inner_k_tiles = inner_k_tiles
@@ -598,6 +605,9 @@ def __init__(
         group_size: Optional[int] = 64,
         scale_precision: torch.dtype = torch.bfloat16,
     ):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Float8ActInt4WeightQATQuantizer"
+        )
         if group_size is not None:
             weight_granularity = "per_group"
         else:

torchao/quantization/quant_api.py
Lines changed: 58 additions & 1 deletion

@@ -133,6 +133,7 @@
 
 logger = logging.getLogger(__name__)
 
+# TODO: revisit this list?
 __all__ = [
     "swap_conv2d_1x1_to_linear",
     "Quantizer",
@@ -619,6 +620,8 @@ def quantize_(
         quantize_(m, int4_weight_only(group_size=32))
 
     """
+    torch._C._log_api_usage_once("torchao.quantization.quantize_")
+
     filter_fn = _is_linear if filter_fn is None else filter_fn
 
     if isinstance(config, ModuleFqnToConfig):
@@ -743,6 +746,11 @@ class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
     act_mapping_type: MappingType = MappingType.ASYMMETRIC
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationInt4WeightConfig"
+        )
+
 
 # for BC
 int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
@@ -854,6 +862,9 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
     layout: Layout = QDQLayout()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationIntxWeightConfig"
+        )
         assert TORCH_VERSION_AT_LEAST_2_6, (
             "Int8DynamicActivationIntxWeightConfig requires torch 2.6+"
         )
@@ -1004,6 +1015,11 @@ class Int4DynamicActivationInt4WeightConfig(AOBaseConfig):
     act_mapping_type: MappingType = MappingType.SYMMETRIC
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int4DynamicActivationInt4WeightConfig"
+        )
+
 
 # for bc
 int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
@@ -1060,6 +1076,11 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     mode: Optional[str] = "weight_only"
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.GemliteUIntXWeightOnlyConfig"
+        )
+
 
 # for BC
 gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
@@ -1133,6 +1154,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     packing_format: PackingFormat = PackingFormat.PLAIN
     VERSION: int = 1
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
+
 
 # for BC
 # TODO maybe change other callsites
@@ -1305,6 +1329,9 @@ class Int8WeightOnlyConfig(AOBaseConfig):
     group_size: Optional[int] = None
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
+
 
 # for BC
 int8_weight_only = Int8WeightOnlyConfig
@@ -1461,6 +1488,11 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     weight_only_decode: bool = False
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8DynamicActivationInt8WeightConfig"
+        )
+
 
 # for BC
 int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
@@ -1565,6 +1597,9 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     set_inductor_config: bool = True
     version: int = 2
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
+
 
 # for BC
 float8_weight_only = Float8WeightOnlyConfig
@@ -1713,9 +1748,11 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     version: int = 2
 
     def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
+        )
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
-
         activation_granularity, weight_granularity = _normalize_granularity(
             self.granularity
         )
@@ -1832,6 +1869,11 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e5m2_dtype
     weight_dtype: torch.dtype = e4m3_dtype
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8SemiSparseWeightConfig"
+        )
+
 
 @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
 def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
@@ -1883,6 +1925,11 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     mm_config: Optional[Float8MMConfig] = Float8MMConfig(use_fast_accum=True)
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8StaticActivationFloat8WeightConfig"
+        )
+
 
 # for bc
 float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
@@ -1963,6 +2010,9 @@ class UIntXWeightOnlyConfig(AOBaseConfig):
     use_hqq: bool = False
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
+
 
 # for BC
 uintx_weight_only = UIntXWeightOnlyConfig
@@ -2062,6 +2112,7 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     layout: Layout = QDQLayout()
 
     def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig")
         assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+"
         assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], (
             f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
@@ -2136,6 +2187,9 @@ class FPXWeightOnlyConfig(AOBaseConfig):
     mbits: int
     set_inductor_config: bool = True
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
+
 
 # for BC
 fpx_weight_only = FPXWeightOnlyConfig
@@ -2267,6 +2321,9 @@ class ModuleFqnToConfig(AOBaseConfig):
         default_factory=dict
     )
 
+    def __post_init__(self):
+        torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig")
+
 
 def _module_fqn_to_config_handler(
     module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig
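
Note for reviewers: with both the config __post_init__ hooks and the top-level quantize_ entry point instrumented, one workflow run now records the config class and the entry point. A minimal sketch (int8 weight-only quantization is generally CPU-friendly, so this should run without accelerator hardware):

    import torch.nn as nn

    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    model = nn.Sequential(nn.Linear(256, 256))

    # Logs "torchao.quantization.Int8WeightOnlyConfig" (from __post_init__) ...
    config = Int8WeightOnlyConfig()

    # ... and "torchao.quantization.quantize_" (from the entry point).
    quantize_(model, config)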
