Skip to content

Commit a5f6aff

Browse files
authored
Add Aten operations
Differential Revision: D79119897 Pull Request resolved: #2664
1 parent 3c466f8 commit a5f6aff

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h>
8+
9+
// Declares the schema of the weight-packing operator for one bit width.
//
// Expands to a single `m.def(...)` expression, so it must be used where a
// library handle named `m` is in scope (the TORCH_LIBRARY_FRAGMENT body in
// this file). `weight_nbit` is stringized into the operator name, e.g.
// DEFINE_PACK_OP(4) declares `_pack_groupwise_4bit_weight_with_lut`.
//
// The expansion deliberately has no trailing semicolon: the call site
// supplies it, so `DEFINE_PACK_OP(n);` is exactly one statement.
#define DEFINE_PACK_OP(weight_nbit) \
  m.def( \
      "_pack_groupwise_" #weight_nbit \
      "bit_weight_with_lut(Tensor weight_qval_idxs, Tensor luts, int scale_group_size, int lut_group_size, Tensor? weight_scales, Tensor? bias, str? target) -> Tensor")
13+
14+
// Declares the schemas of the linear operator for one bit width: the
// functional overload and its `.out` overload, which writes into a
// caller-provided `out` tensor (annotated `Tensor(a!)` in the schema).
//
// Must be used where a library handle named `m` is in scope. `weight_nbit`
// is stringized into the operator name, e.g. DEFINE_LINEAR_OP(4) declares
// `_linear_groupwise_4bit_weight_with_lut` and
// `_linear_groupwise_4bit_weight_with_lut.out`.
//
// The two registrations are wrapped in do { } while (0) so the macro acts
// as one statement (both calls stay together under an unbraced `if`); the
// call site supplies the terminating semicolon.
#define DEFINE_LINEAR_OP(weight_nbit) \
  do { \
    m.def( \
        "_linear_groupwise_" #weight_nbit \
        "bit_weight_with_lut(Tensor activations, Tensor packed_weights, int scale_group_size, int lut_group_size, int n, int k) -> Tensor"); \
    m.def( \
        "_linear_groupwise_" #weight_nbit \
        "bit_weight_with_lut.out(Tensor activations, Tensor packed_weights, int scale_group_size, int lut_group_size, int n, int k, *, Tensor(a!) out) -> Tensor(a!)"); \
  } while (0)
21+
22+
// Binds `pack_weights_with_lut_cpu<weight_nbit>` as the implementation of
// `_pack_groupwise_<weight_nbit>bit_weight_with_lut`.
//
// Expands to a single `m.impl(...)` expression, so it must be used where a
// library handle named `m` is in scope (the CPU TORCH_LIBRARY_IMPL body in
// this file). No trailing semicolon is emitted; the call site supplies it.
#define DEFINE_PACK_CPU_IMPL(weight_nbit) \
  m.impl( \
      "_pack_groupwise_" #weight_nbit "bit_weight_with_lut", \
      &pack_weights_with_lut_cpu<weight_nbit>)
26+
27+
// Binds `pack_weights_with_lut_meta<weight_nbit>` as the implementation of
// `_pack_groupwise_<weight_nbit>bit_weight_with_lut`.
//
// Expands to a single `m.impl(...)` expression, so it must be used where a
// library handle named `m` is in scope (the Meta TORCH_LIBRARY_IMPL body in
// this file). No trailing semicolon is emitted; the call site supplies it.
#define DEFINE_PACK_META_IMPL(weight_nbit) \
  m.impl( \
      "_pack_groupwise_" #weight_nbit "bit_weight_with_lut", \
      &pack_weights_with_lut_meta<weight_nbit>)
31+
32+
// Binds the linear kernels for one bit width:
//   `_linear_groupwise_<weight_nbit>bit_weight_with_lut`     -> linear_cpu<weight_nbit>
//   `_linear_groupwise_<weight_nbit>bit_weight_with_lut.out` -> linear_out_cpu<weight_nbit>
//
// Must be used where a library handle named `m` is in scope (the CPU
// TORCH_LIBRARY_IMPL body in this file).
//
// The two registrations are wrapped in do { } while (0) so the macro acts
// as one statement (both calls stay together under an unbraced `if`); the
// call site supplies the terminating semicolon.
#define DEFINE_LINEAR_CPU_IMPL(weight_nbit) \
  do { \
    m.impl( \
        "_linear_groupwise_" #weight_nbit "bit_weight_with_lut", \
        &linear_cpu<weight_nbit>); \
    m.impl( \
        "_linear_groupwise_" #weight_nbit "bit_weight_with_lut.out", \
        &linear_out_cpu<weight_nbit>); \
  } while (0)
39+
40+
// Binds `linear_meta<weight_nbit>` as the implementation of
// `_linear_groupwise_<weight_nbit>bit_weight_with_lut`.
//
// Expands to a single `m.impl(...)` expression, so it must be used where a
// library handle named `m` is in scope (the Meta TORCH_LIBRARY_IMPL body in
// this file). No trailing semicolon is emitted; the call site supplies it.
//
// The final line must NOT end in a backslash: a dangling continuation makes
// the macro absorb whatever line happens to follow it.
//
// NOTE(review): only the functional overload is bound here, whereas
// DEFINE_LINEAR_CPU_IMPL also binds the `.out` overload — confirm the
// omission for Meta is intentional.
#define DEFINE_LINEAR_META_IMPL(weight_nbit) \
  m.impl( \
      "_linear_groupwise_" #weight_nbit "bit_weight_with_lut", \
      &linear_meta<weight_nbit>)
44+
45+
46+
// Declares the operator schemas in the `torchao` library for weight bit
// widths 1 through 4: one packing operator and one linear operator (plus its
// `.out` overload) per width. Only schemas are declared here; the
// implementations are bound by the TORCH_LIBRARY_IMPL blocks in this file.
TORCH_LIBRARY_FRAGMENT(torchao, m) {
  // Weight-packing operators.
  DEFINE_PACK_OP(1);
  DEFINE_PACK_OP(2);
  DEFINE_PACK_OP(3);
  DEFINE_PACK_OP(4);

  // Linear operators (functional and `.out` overloads).
  DEFINE_LINEAR_OP(1);
  DEFINE_LINEAR_OP(2);
  DEFINE_LINEAR_OP(3);
  DEFINE_LINEAR_OP(4);
}
57+
58+
// Binds the CPU implementations of the `torchao` packing and linear
// operators for weight bit widths 1 through 4.
TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  // Weight-packing kernels.
  DEFINE_PACK_CPU_IMPL(1);
  DEFINE_PACK_CPU_IMPL(2);
  DEFINE_PACK_CPU_IMPL(3);
  DEFINE_PACK_CPU_IMPL(4);

  // Linear kernels (functional and `.out` overloads).
  DEFINE_LINEAR_CPU_IMPL(1);
  DEFINE_LINEAR_CPU_IMPL(2);
  DEFINE_LINEAR_CPU_IMPL(3);
  DEFINE_LINEAR_CPU_IMPL(4);
}
69+
70+
// Binds the Meta dispatch-key implementations of the `torchao` packing and
// linear operators for weight bit widths 1 through 4.
TORCH_LIBRARY_IMPL(torchao, Meta, m) {
  // Weight-packing implementations.
  DEFINE_PACK_META_IMPL(1);
  DEFINE_PACK_META_IMPL(2);
  DEFINE_PACK_META_IMPL(3);
  DEFINE_PACK_META_IMPL(4);

  // Linear implementations.
  DEFINE_LINEAR_META_IMPL(1);
  DEFINE_LINEAR_META_IMPL(2);
  DEFINE_LINEAR_META_IMPL(3);
  DEFINE_LINEAR_META_IMPL(4);
}

0 commit comments

Comments
 (0)