From 9835f3a3e180afc162751b2147712c02759dd5c2 Mon Sep 17 00:00:00 2001 From: Zeyu Song Date: Mon, 11 Aug 2025 16:21:30 -0700 Subject: [PATCH] Add meta function for linear operation (groupwise lut kernel). (#2704) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/2704 Reviewed By: metascroy Differential Revision: D79401683 --- torchao/experimental/op_lib.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torchao/experimental/op_lib.py b/torchao/experimental/op_lib.py index 456b0ca160..182d1c3312 100644 --- a/torchao/experimental/op_lib.py +++ b/torchao/experimental/op_lib.py @@ -84,3 +84,20 @@ def _(packed_weights: Tensor, group_size: int, n: int, k: int, indices: Tensor): assert indices.dim() == 1 num_out = indices.shape[0] return torch.empty(num_out, k, dtype=torch.float32, device="meta") + + +for weight_nbit in range(1, 5): + + @impl(torchao_lib, f"_linear_groupwise_{weight_nbit}bit_weight_with_lut", "Meta") + def _( + activations: Tensor, + packed_weights: Tensor, + scale_group_size: int, + lut_group_size: int, + n: int, + k: int, + ): + assert activations.dim() == 2 + m, k_ = activations.shape + assert k_ == k + return torch.empty(m, n, dtype=activations.dtype, device="meta")