diff --git a/torchao/experimental/op_lib.py b/torchao/experimental/op_lib.py index 456b0ca160..182d1c3312 100644 --- a/torchao/experimental/op_lib.py +++ b/torchao/experimental/op_lib.py @@ -84,3 +84,20 @@ def _(packed_weights: Tensor, group_size: int, n: int, k: int, indices: Tensor): assert indices.dim() == 1 num_out = indices.shape[0] return torch.empty(num_out, k, dtype=torch.float32, device="meta") + + +for weight_nbit in range(1, 5): + + @impl(torchao_lib, f"_linear_groupwise_{weight_nbit}bit_weight_with_lut", "Meta") + def _( + activations: Tensor, + packed_weights: Tensor, + scale_group_size: int, + lut_group_size: int, + n: int, + k: int, + ): + assert activations.dim() == 2 + m, k_ = activations.shape + assert k_ == k + return torch.empty(m, n, dtype=activations.dtype, device="meta")