Commit 65cdd3e

Align Int4Tensor implementation details with the design of Float8Tensor
Summary: Int4Tensor is the non-preshuffled variant of the int4 quantized tensor: the packed data has shape [N, K/2], and the scale/zero_point tensors have shape [K/group_size, N]. This commit applies multiple fixes to Int4Tensor so that it follows the design of Float8Tensor (calling only fbgemm ops):

* Added VERSION 2 for Int4WeightOnlyConfig
* Migrated the op implementations and tests from #2387

Test Plan: python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16
1 parent bfe34b5 commit 65cdd3e
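To make the layout in the summary concrete, below is a minimal sketch of what the quantized weight is expected to look like. It is not code from this commit: the version=2 keyword and the qdata/scale/zero_point attribute names are assumptions based on the summary and on the Float8Tensor design this change aligns with (Float8Tensor exposes qdata/scale in the tests below).

# Minimal sketch of the non-preshuffled Int4Tensor layout (assumptions noted above).
# Requires a CUDA build with the fbgemm int4 kernels available.
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

N, K, group_size = 256, 512, 128
m = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
# `version=2` refers to the VERSION 2 added in this stack (assumed kwarg name)
quantize_(m, Int4WeightOnlyConfig(group_size=group_size, version=2))

w = m.weight
# non-preshuffled layout: two int4 values packed per byte along the K dimension
assert w.qdata.shape == (N, K // 2)               # assumed attribute name
# one scale/zero_point per (group, output channel)
assert w.scale.shape == (K // group_size, N)
assert w.zero_point.shape == (K // group_size, N)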

File tree

8 files changed, +682 -375 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 5 additions & 218 deletions
@@ -10,15 +10,11 @@
 from typing import Tuple
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
-    TestCase,
     run_tests,
 )
 
-from torchao.prototype.moe_quant.utils import MoEQuantConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
@@ -28,6 +24,7 @@
 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
+from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     _is_fbgemm_genai_gpu_available,
@@ -39,66 +36,6 @@
 torch._dynamo.config.cache_size_limit = 128
 
 
-class Experts(nn.Module):
-    def __init__(
-        self,
-        num_local_experts: int,
-        dim: int,
-        hidden_dim: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> None:
-        super().__init__()
-
-        self.num_local_experts = num_local_experts
-        self.dim = dim
-
-        self.w1: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w2: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                hidden_dim,
-                dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w3: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-    def forward(
-        self,
-        routed_in_egD: torch.Tensor,  # noqa: N803
-    ) -> torch.Tensor:
-        e = self.num_local_experts
-        D = self.dim
-
-        x_egD = routed_in_egD.view(e, -1, D)
-
-        middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
-        out_egD = torch.bmm(middle_out_egF, self.w2)
-        out_egD = out_egD.view(-1, D)
-
-        return out_egD
-
-
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
         super().__init__()
@@ -115,7 +52,7 @@ def forward(self, x):
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
-class TestFloat8Tensor(TestCase):
+class TestFloat8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
@@ -340,45 +277,8 @@ def test_slice_preserves_aliasing(self, granularity):
 
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_and_copy_similar_to_vllm(self, granularity):
-        # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
-        # the test is similar to the linked code, but with some hardcoded arguments
-        # and does not use tensor parallelism
-
-        dtype = torch.bfloat16
-        device = "cuda"
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
-        quantize_(l, config)
-
-        # high level, we do a narrow for both param.data and the loaded_weights
-        # and do inplace copy_ to copy from the loaded_weights into param.data
-
-        # simulate loaded_weight
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        # making the weight different
-        dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
-            requires_grad=False,
-        )
-        quantize_(dummy_l, config)
-
-        output_dim = 0
-        shard_size = 512
-        for tp_rank in [0, 1]:
-            start_idx = tp_rank * shard_size
-            param = l.weight
-            param_data = param.data
-            param_data = param_data.narrow(output_dim, start_idx, shard_size)
-            orig_value = param_data.qdata[0][0].item()
-            loaded_weight = dummy_l.weight
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-
-            # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-            assert orig_value != loaded_weight.qdata[0][0]
-            param_data.copy_(loaded_weight)
-            # making sure param.data is updated to loaded_weight
-            assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
-            assert param_data.scale[0] == loaded_weight.scale[0]
+        self._test_slice_and_copy_similar_to_vllm(config)
 
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_bmm(self):
@@ -494,122 +394,9 @@ def test_cat(self, granularity, sizes):
 
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_moe_weight_reshape_ops(self):
-        """This is testing the op call sequence in saving and loading quantization
-        checkpoints in llama-models for llama4
-        (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
-        """
-        # only per row quantization is supported for bmm
         granularity = PerRow()
-        dtype = torch.bfloat16
-        device = "cuda"
-
-        bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        moe_config = MoEQuantConfig(bmm_config)
-
-        batch_size = 4
-        num_experts = 2
-        input_dim = 64
-        dim = 128
-        hidden_dim = 256
-
-        moe1 = Experts(num_experts, dim, hidden_dim, dtype, device)
-        moe2 = Experts(num_experts, dim, hidden_dim, dtype, device)
-        moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device)
-        input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device)
-
-        moes = [moe1, moe2]
-
-        for moe in moes:
-            moe(input)
-
-            def filter_fn(module, fqn):
-                return isinstance(module, Experts)
-
-            # need to transpose before quantizing
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.transpose(1, 2).contiguous(), requires_grad=False
-            )
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.transpose(1, 2).contiguous(), requires_grad=False
-            )
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.transpose(1, 2).contiguous(), requires_grad=False
-            )
-
-            quantize_(moe, moe_config, filter_fn=filter_fn)
-
-            # make sure it runs
-            before = moe(input)
-
-            # transposing for resharding support since only 2D resharding is supported
-            new_last_dim = moe.w1.shape[-2]
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-            new_last_dim = moe.w2.shape[-2]
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-            new_last_dim = moe.w3.shape[-2]
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-
-            # transpose again to recover the original weights
-            moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False)
-            moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False)
-            moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False)
-
-            # make sure it runs
-            after = moe(input)
-
-            self.assertEqual(before, after)
-
-        state_dicts = [moe1.state_dict(), moe2.state_dict()]
-        # align the scale parameter so they can be concatenated
-        for key in ["w1", "w2", "w3"]:
-            weights = [st[key] for st in state_dicts]
-            for i in range(1, len(weights)):
-                weights[i].scale = weights[0].scale
-
-        def process_key(key: str) -> torch.Tensor:
-            tensors = [s[key] for s in state_dicts]
-            # Note: we have a hacky implementation for cat in user codebase
-            # since it is not implemented correctly before
-            if key == "w2":
-                return torch.cat(tensors, dim=-1)
-            else:
-                return torch.cat(tensors, dim=-2)
-
-        new_state_dict = {}
-        for key in ["w1", "w2", "w3"]:
-            new_state_dict[key] = process_key(key)
-
-        moe_combined.w1 = torch.nn.Parameter(
-            moe_combined.w1.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.w2 = torch.nn.Parameter(
-            moe_combined.w2.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.w3 = torch.nn.Parameter(
-            moe_combined.w3.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.load_state_dict(new_state_dict, assign=True)
-        # make sure it runs
-        moe_combined(input)
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        self._test_moe_weight_reshape_ops(config)
 
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
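The point of moving these test bodies into TorchAOIntegrationTestCase is that the int4 tests referenced in the Test Plan can reuse the same helpers. A hypothetical sketch of that pattern follows; the int4 test class and the Int4WeightOnlyConfig kwargs shown here are assumptions based on the commit summary, not code from this commit.

# Hypothetical reuse of the shared helpers added via TorchAOIntegrationTestCase.
import unittest

import torch

from torchao.quantization import Int4WeightOnlyConfig
from torchao.testing.utils import TorchAOIntegrationTestCase


class TestInt4TensorIntegration(TorchAOIntegrationTestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_slice_and_copy_similar_to_vllm(self):
        # delegate to the same helper TestFloat8Tensor now calls
        config = Int4WeightOnlyConfig(group_size=128, version=2)  # assumed kwargs
        self._test_slice_and_copy_similar_to_vllm(config)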
