refactor commonly used toy model #2729

Open · wants to merge 4 commits into main
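
This PR replaces the per-file ToyLinearModel definitions removed below with shared helpers imported from torchao.testing.model_architectures. A rough sketch of what those helpers look like, reconstructed from the inline classes removed in this diff (the names come from the new imports; the defaults and exact signatures are assumptions, not the library's actual source):

    import torch


    class ToySingleLinearModel(torch.nn.Module):
        """Single linear for an m * k * n problem size (cf. benchmarks/benchmark_aq.py)."""

        def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
            super().__init__()
            self.m = m
            self.dtype = dtype
            self.device = device
            self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=dtype, device=device)

        def example_inputs(self):
            return (
                torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),
            )

        def forward(self, x):
            return self.linear(x)


    class ToyMultiLinearModel(torch.nn.Module):
        """Two stacked linears, the common case in the docs and tests touched here."""

        def __init__(self, m=64, n=32, k=64):
            super().__init__()
            self.linear1 = torch.nn.Linear(m, n, bias=False)
            self.linear2 = torch.nn.Linear(n, k, bias=False)

        def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
            return (
                torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),
            )

        def forward(self, x):
            return self.linear2(self.linear1(x))

The actual shared helper is likely more configurable than this two-layer sketch: test/prototype/test_awq.py constructs ToyMultiLinearModel with three layer sizes and sequence-shaped example inputs, while test/dtypes/test_affine_quantized_float.py constructs it with only two dimensions.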
29 changes: 2 additions & 27 deletions benchmarks/benchmark_aq.py
@@ -20,6 +20,7 @@
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.testing.model_architectures import ToySingleLinearModel
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
@@ -62,32 +63,6 @@ def _int4wo_api(mod, **kwargs):
change_linear_weights_to_int4_woqtensors(mod, **kwargs)


class ToyLinearModel(torch.nn.Module):
"""Single linear for m * k * n problem size"""

def __init__(
self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
):
super().__init__()
self.m = m
self.dtype = dtype
self.device = device
self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
dtype=self.dtype, device=self.device
)

def example_inputs(self):
return (
torch.randn(
self.m, self.linear.in_features, dtype=self.dtype, device=self.device
),
)

def forward(self, x):
x = self.linear(x)
return x


def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for int8 dynamic quant API, used as a reference for
@@ -151,7 +126,7 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
if kwargs is None:
kwargs = {}

m = ToyLinearModel(
m = ToySingleLinearModel(
M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
).eval()
m_bf16 = copy.deepcopy(m)
14 changes: 2 additions & 12 deletions docs/source/quick_start.rst
@@ -29,19 +29,9 @@ First, let's set up our toy model:

import copy
import torch
from torchao.testing.model_architectures import ToyMultiLinearModel

class ToyLinearModel(torch.nn.Module):
def __init__(self, m: int, n: int, k: int):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x

model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
model = ToyMultiLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")

# Optional: compile model for faster inference and generation
model = torch.compile(model, mode="max-autotune", fullgraph=True)
29 changes: 7 additions & 22 deletions docs/source/serialization.rst
@@ -7,7 +7,7 @@ Serialization and deserialization flow
======================================

Here is the serialization and deserialization flow::

import copy
import tempfile
import torch
@@ -16,23 +16,10 @@ Here is the serialization and deserialization flow::
quantize_,
Int4WeightOnlyConfig,
)

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)

def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
from torchao.testing.model_architectures import ToyMultiLinearModel

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m = ToyMultiLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
print(f"original model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB")

example_inputs = m.example_inputs(dtype=dtype, device="cuda")
@@ -46,7 +33,7 @@ Here is the serialization and deserialization flow::
state_dict = torch.load(f)

with torch.device("meta"):
m_loaded = ToyLinearModel(1024, 1024, 1024).eval().to(dtype)
m_loaded = ToyMultiLinearModel(1024, 1024, 1024).eval().to(dtype)

# `linear.weight` is nn.Parameter, so we check the type of `linear.weight.data`
print(f"type of weight before loading: {type(m_loaded.linear1.weight.data), type(m_loaded.linear2.weight.data)}")
@@ -62,7 +49,7 @@ What happens when serializing an optimized model?
To serialize an optimized model, we just need to call ``torch.save(m.state_dict(), f)``, because in torchao we use tensor subclasses to represent different dtypes and to support different optimization techniques such as quantization and sparsity. So after optimization, the only change is that each weight Tensor is replaced with an optimized weight Tensor; the model structure is not changed at all. For example:

original floating point model ``state_dict``::

{"linear1.weight": float_weight1, "linear2.weight": float_weight2}

quantized model ``state_dict``::
@@ -75,14 +62,14 @@ The size of the quantized model is typically going to be smaller than the original
original model size: 4.0 MB
quantized model size: 1.0625 MB
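
These numbers are consistent with the layer shapes used above; a back-of-the-envelope estimate, assuming int4 weight-only quantization with grouped scales::

    bf16 weights:  2 layers x 1024 x 1024 x 2 bytes   = 4.0 MB
    int4 weights:  2 layers x 1024 x 1024 x 0.5 bytes = 1.0 MB
    quantization metadata (scales / zero points)      ~ 0.06 MB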


What happens when deserializing an optimized model?
===================================================
To deserialize an optimized model, we can initialize the floating point model on the `meta <https://pytorch.org/docs/stable/meta.html>`__ device and then load the optimized ``state_dict`` with ``assign=True`` using `model.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`__::


with torch.device("meta"):
m_loaded = ToyLinearModel(1024, 1024, 1024).eval().to(dtype)
m_loaded = ToyMultiLinearModel(1024, 1024, 1024).eval().to(dtype)

print(f"type of weight before loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
m_loaded.load_state_dict(state_dict, assign=True)
@@ -97,5 +84,3 @@ We can also verify that the weight is properly loaded by checking the type of we

type of weight before loading: (<class 'torch.Tensor'>, <class 'torch.Tensor'>)
type of weight after loading: (<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>, <class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>)


16 changes: 2 additions & 14 deletions scripts/quick_start.py
@@ -8,6 +8,7 @@
import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.testing.model_architectures import ToyMultiLinearModel
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
benchmark_model,
@@ -18,20 +19,7 @@
# | Set up model |
# ================


class ToyLinearModel(torch.nn.Module):
def __init__(self, m: int, n: int, k: int):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
model = ToyMultiLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")

# Optional: compile model for faster inference and generation
model = torch.compile(model, mode="max-autotune", fullgraph=True)
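
# A minimal sketch of the quick start's next step after the setup above:
# quantizing the shared toy model with the quantize_ and Int4WeightOnlyConfig
# already imported in this script. The group size and input shape below are
# illustrative assumptions, not values taken from this diff.

# Swap the linear weights in place for int4 weight-only quantized tensors.
quantize_(model, Int4WeightOnlyConfig(group_size=128))

# Run a bf16 input whose last dimension matches linear1.in_features (1024).
example_input = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
output = model(example_input)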
23 changes: 6 additions & 17 deletions test/dtypes/test_affine_quantized_float.py
@@ -45,6 +45,7 @@
_quantize_affine_float8,
choose_qparams_affine,
)
from torchao.testing.model_architectures import ToyMultiLinearModel
from torchao.utils import (
is_sm_at_least_89,
is_sm_at_least_90,
@@ -55,18 +56,6 @@
torch.manual_seed(0)


class ToyLinearModel(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


class TestAffineQuantizedFloat8Compile(InductorTestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
@@ -129,7 +118,7 @@ def test_fp8_linear_variants(
}

# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
model = ToyMultiLinearModel(K, N).eval().to(dtype).to("cuda")

quantized_model = copy.deepcopy(model)
factory = mode_map[mode]()
@@ -186,7 +175,7 @@ def test_per_row_with_float32(self):
AssertionError,
match="PerRow quantization only works for bfloat16 precision",
):
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
model = ToyMultiLinearModel(64, 64).eval().to(torch.float32).to("cuda")
quantize_(
model,
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
@@ -199,7 +188,7 @@ def test_per_row_with_float32(self):
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
def test_serialization(self, mode: str):
# Create and quantize the model
model = ToyLinearModel(16, 32).to(device="cuda")
model = ToyMultiLinearModel(16, 32).to(device="cuda")

mode_map = {
"dynamic": partial(
@@ -231,7 +220,7 @@ def test_serialization(self, mode: str):

# Create a new model and load the state dict
with torch.device("meta"):
new_model = ToyLinearModel(16, 32)
new_model = ToyMultiLinearModel(16, 32)
if mode == "static":
quantize_(new_model, factory)
new_model.load_state_dict(loaded_state_dict, assign=True)
@@ -273,7 +262,7 @@ def test_serialization(self, mode: str):
)
def test_fp8_weight_dimension_warning(self):
# Create model with incompatible dimensions (not multiples of 16)
model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights
model = ToyMultiLinearModel(10, 25).cuda() # 10x25 and 25x10 weights

# Set up logging capture
with self.assertLogs(
20 changes: 2 additions & 18 deletions test/integration/test_integration.py
@@ -2134,28 +2134,12 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):


class TestBenchmarkModel(unittest.TestCase):
class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)

def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
return (
torch.randn(
batch_size, self.linear1.in_features, dtype=dtype, device=device
),
)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
from torchao.testing.model_architectures import ToyMultiLinearModel

def run_benchmark_model(self, device):
# params
dtype = torch.bfloat16
m = self.ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to(device)
m = self.ToyMultiLinearModel(1024, 1024, 1024).eval().to(dtype).to(device)
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device=device)
m_bf16 = torch.compile(m_bf16, mode="max-autotune")
35 changes: 6 additions & 29 deletions test/prototype/test_awq.py
@@ -15,36 +15,13 @@

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
from torchao.testing.model_architectures import ToyMultiLinearModel
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
_is_fbgemm_genai_gpu_available,
)


class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)
self.linear3 = torch.nn.Linear(k, 64, bias=False)

def example_inputs(
self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
):
return [
torch.randn(
1, sequence_length, self.linear1.in_features, dtype=dtype, device=device
)
for j in range(batch_size)
]

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x


@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(),
@@ -77,7 +54,7 @@ def test_awq_functionality(self):
n_calibration_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
m = ToyMultiLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

# baseline quantization
base_config = FbgemmConfig(
@@ -126,7 +103,7 @@ def test_awq_loading(self):
n_calibration_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
m = ToyMultiLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
Expand Down Expand Up @@ -158,7 +135,7 @@ def test_awq_loading(self):
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
loaded_model = ToyMultiLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
loaded_model.load_state_dict(state_dict, assign=True)

m = torch.compile(m, fullgraph=True)
@@ -186,7 +163,7 @@ def test_awq_loading_vllm(self):
n_calibration_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
m = ToyMultiLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
Expand Down Expand Up @@ -218,7 +195,7 @@ def test_awq_loading_vllm(self):
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
loaded_model = ToyMultiLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
quantize_(loaded_model, quant_config)
