Skip to content

Commit 761491c

Browse files
szyszyzysfacebook-github-bot
authored andcommitted
Add meta function for linear operation (groupwise lut kernel). (#2704)
Summary: Pull Request resolved: #2704 Reviewed By: metascroy Differential Revision: D79401683
1 parent 7757ab6 commit 761491c

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

torchao/experimental/op_lib.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,20 @@ def _(packed_weights: Tensor, group_size: int, n: int, k: int, indices: Tensor):
8484
assert indices.dim() == 1
8585
num_out = indices.shape[0]
8686
return torch.empty(num_out, k, dtype=torch.float32, device="meta")
87+
88+
89+
for weight_nbit in range(1, 5):
90+
91+
@impl(torchao_lib, f"_linear_groupwise_{weight_nbit}bit_weight_with_lut", "Meta")
92+
def _(
93+
activations: Tensor,
94+
packed_weights: Tensor,
95+
scale_group_size: int,
96+
lut_group_size: int,
97+
n: int,
98+
k: int,
99+
):
100+
assert activations.dim() == 2
101+
m, k_ = activations.shape
102+
assert k_ == k
103+
return torch.empty(m, n, dtype=activations.dtype, device="meta")

0 commit comments

Comments
 (0)