
Commit 5f6306e

Align Int4Tensor implementation details with the design of Float8Tensor
Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized tensor: data has shape [N, K/2], and scale/zero_point have shape [K/group_size, N].

Multiple fixes for Int4Tensor to align it with the design of Float8Tensor (only calling fbgemm ops):
* Added VERSION 2 for Int4WeightOnlyConfig
* Migrated op implementation and tests from #2387

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16
1 parent ed552a2 commit 5f6306e
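For readers skimming the layout described in the summary, here is a minimal shape sketch. The values of N, K, and group_size are illustrative, and the variable names are assumptions for illustration only, not the actual Int4Tensor attribute names:

```python
# Illustrative sketch of the int4 layout described in the commit summary.
# N, K, group_size are example values; names are not the real Int4Tensor fields.
import torch

N, K, group_size = 256, 512, 128

# int4 values are packed two per byte along K, so the packed data is [N, K/2]
packed_data = torch.zeros(N, K // 2, dtype=torch.uint8)

# one scale / zero_point per (input group, output channel): [K/group_size, N]
scale = torch.ones(K // group_size, N, dtype=torch.bfloat16)
zero_point = torch.zeros(K // group_size, N, dtype=torch.bfloat16)

print(packed_data.shape)  # torch.Size([256, 256])
print(scale.shape)        # torch.Size([4, 256])
```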

5 files changed (+644, -204 lines)


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 1 addition & 62 deletions
@@ -10,8 +10,6 @@
 from typing import Tuple
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -28,6 +26,7 @@
 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
+from torchao.testing.model_architectures import Experts
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     _is_fbgemm_genai_gpu_available,
@@ -39,66 +38,6 @@
 torch._dynamo.config.cache_size_limit = 128
 
 
-class Experts(nn.Module):
-    def __init__(
-        self,
-        num_local_experts: int,
-        dim: int,
-        hidden_dim: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> None:
-        super().__init__()
-
-        self.num_local_experts = num_local_experts
-        self.dim = dim
-
-        self.w1: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w2: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                hidden_dim,
-                dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w3: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-    def forward(
-        self,
-        routed_in_egD: torch.Tensor,  # noqa: N803
-    ) -> torch.Tensor:
-        e = self.num_local_experts
-        D = self.dim
-
-        x_egD = routed_in_egD.view(e, -1, D)
-
-        middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
-        out_egD = torch.bmm(middle_out_egF, self.w2)
-        out_egD = out_egD.view(-1, D)
-
-        return out_egD
-
-
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
         super().__init__()
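Note on the removed class: the Experts test model is no longer defined inline and is instead imported from torchao.testing.model_architectures, as the added import above shows. A minimal usage sketch, assuming the moved class keeps the same constructor and forward signature as the removed inline definition:

```python
import torch
from torchao.testing.model_architectures import Experts

# Assumed to mirror the removed inline definition:
# Experts(num_local_experts, dim, hidden_dim, dtype, device)
experts = Experts(
    num_local_experts=2,
    dim=64,
    hidden_dim=128,
    dtype=torch.float32,
    device=torch.device("cpu"),
)

# forward views the input as [num_local_experts, tokens_per_expert, dim],
# so pass routed tokens flattened to [num_local_experts * tokens, dim]
routed_in = torch.randn(2 * 8, 64)
out = experts(routed_in)
print(out.shape)  # torch.Size([16, 64])
```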
