Skip to content

Commit 3515cb6

Browse files
authored
Update the op_-impl.h
Differential Revision: D77623746 Pull Request resolved: #2621
1 parent 0561b1a commit 3515cb6

File tree

1 file changed

+266
-0
lines changed

1 file changed

+266
-0
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <ATen/Functions.h>
10+
#include <torch/library.h>
11+
#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h>
12+
#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h>
13+
#include <torchao/experimental/ops/library.h>
14+
#include <torchao/experimental/ops/packed_weights_header.h>
15+
#include <optional>
16+
#include <vector>
17+
18+
namespace {
19+
20+
#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
// Groupwise low-bit LUT-quantized linear: writes activations @ weights^T
// into `out`, resizing it to (m, n), and returns `out`.
//
// activations: 2D (m, k) float32 tensor.
// packed_weights: 1D int8 buffer produced by the packing op; its first
//   PackedWeightsHeader::size() bytes are the serialized header, followed by
//   the kernel-specific payload.
// scale_group_size / lut_group_size: grouping parameters used at pack time.
// n, k: weight matrix dimensions (out features, in features).
template <int weight_nbit>
Tensor linear_out_cpu(
    const Tensor& activations,
    const Tensor& packed_weights,
    const int64_t& scale_group_size,
    const int64_t& lut_group_size,
    const int64_t& n,
    const int64_t& k,
    Tensor& out) {
  TORCHAO_CHECK(n >= 1, "n must be >= 1");
  TORCHAO_CHECK(k >= 1, "k must be >= 1");
  TORCHAO_CHECK(lut_group_size >= 1, "lut_group_size must be >= 1");

#ifdef USE_ATEN
  TORCHAO_CHECK(
      activations.dtype() == torch::kFloat32, "activations must be float32");
#endif // USE_ATEN

  TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D");
  const int num_rows = activations.size(0);
  const int activation_cols = activations.size(1);
  TORCHAO_CHECK(
      k == activation_cols,
      "activation shape is incompatible with packed weights.");

#ifdef USE_ATEN
  TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
#endif // USE_ATEN

  // Explicit cast from int64_t to int is required for Executorch
  TORCHAO_RESIZE_TENSOR(out, {(int)num_rows, (int)n});

  TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D");
#ifdef USE_ATEN
  TORCHAO_CHECK(
      packed_weights.dtype() == torch::kInt8, "packed_weights must be int8");
#endif // USE_ATEN
  TORCHAO_CHECK(
      packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(),
      "packed_weights is not big enough to read the header.");
  const auto header =
      torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr());

  // Select the micro-kernel configuration matching the packing format
  // recorded in the header.
  const auto ukernel_config =
      torchao::ops::groupwise_lowbit_weight_lut::select_ukernel_config<
          weight_nbit>(header);

  torchao::ops::groupwise_lowbit_weight_lut::
      groupwise_lowbit_weight_lut_parallel_operator(
          ukernel_config,
          std::nullopt,
          out.mutable_data_ptr<float>(),
          num_rows,
          n,
          k,
          scale_group_size,
          lut_group_size,
          // Payload starts right after the serialized header.
          packed_weights.const_data_ptr<int8_t>() +
              torchao::ops::PackedWeightsHeader::size(),
          activations.const_data_ptr<float>(),
          /*has_clamp=*/false,
          /*clamp_min=*/0.0,
          /*clamp_max=*/0.0);

  return out;
}
#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH)
86+
87+
#ifdef USE_ATEN
// ATen entry point: allocates the output tensor and forwards to
// linear_out_cpu, which validates inputs and resizes the result to (m, n).
template <int weight_nbit>
Tensor linear_cpu(
    const Tensor& activations,
    const Tensor& packed_weights,
    const int64_t& scale_group_size,
    const int64_t& lut_group_size,
    const int64_t& n,
    const int64_t& k) {
  // Start from an empty float32 tensor; linear_out_cpu resizes it in place.
  Tensor result = torch::empty({}, torch::kFloat32);
  linear_out_cpu<weight_nbit>(
      activations,
      packed_weights,
      scale_group_size,
      lut_group_size,
      n,
      k,
      result);
  return result;
}
#endif // USE_ATEN
108+
109+
#ifdef USE_ATEN
// Meta-device implementation: shape propagation only, no computation.
// Returns an empty tensor shaped like `activations` with the last dimension
// replaced by n.
template <int weight_nbit>
at::Tensor linear_meta(
    const at::Tensor& activations,
    const at::Tensor& packed_weights,
    const int64_t& scale_group_size,
    const int64_t& lut_group_size,
    const int64_t& n,
    const int64_t& k) {
  auto shape = activations.sizes().vec();
  TORCH_CHECK(
      !shape.empty() && shape.back() == k,
      "The last dimension of `activations` is ",
      shape.back(),
      " but it must be equal to k=",
      k);

  // Leading dimensions are preserved; the feature dimension becomes n.
  shape.back() = n;
  return at::empty(shape, activations.options());
}
#endif // USE_ATEN
132+
133+
#ifdef USE_ATEN
// Packs LUT-quantized weights (plus optional groupwise scales and bias) into
// the serialized buffer consumed by linear_out_cpu. The returned 1D int8
// tensor holds a PackedWeightsHeader followed by the kernel-specific payload.
//
// weight_qval_idxs: (n, k) uint8 tensor of per-weight LUT indices.
// luts: float32 lookup tables with 2^weight_nbit entries per table row.
// scale_group_size: weights per scale group (validated only when
//   weight_scales is provided).
// lut_group_size: weights sharing one LUT; must be a multiple of k.
// weight_scales: optional 1D float32 groupwise scales.
// bias: optional 1D float32 bias of length n.
// target: reserved; must be std::nullopt for now.
template <int weight_nbit>
Tensor pack_weights_with_lut_cpu(
    const Tensor& weight_qval_idxs,
    const Tensor& luts,
    int64_t scale_group_size,
    int64_t lut_group_size,
    const std::optional<Tensor>& weight_scales,
    const std::optional<Tensor>& bias,
    const std::optional<std::string>& target) {
  bool has_scales = weight_scales.has_value();
  bool has_bias = bias.has_value();

  TORCHAO_CHECK(
      weight_qval_idxs.dtype() == torch::kUInt8,
      "weight_qval_idxs must be uint8");
  TORCHAO_CHECK(weight_qval_idxs.dim() == 2, "weight_qval_idxs must be 2D");
  int n = weight_qval_idxs.size(0);
  int k = weight_qval_idxs.size(1);
  TORCHAO_CHECK(lut_group_size >= 1, "lut_group_size must be >= 1");

  TORCHAO_CHECK(luts.dtype() == torch::kFloat32, "luts must be float32");
  // Fix: the condition asserts k divides lut_group_size; the previous
  // message ("the number of luts must divide k") described a different
  // (incorrect) condition.
  TORCHAO_CHECK(
      lut_group_size % k == 0, "lut_group_size must be a multiple of k");

  TORCHAO_CHECK(
      luts.size(1) == (1 << weight_nbit),
      "luts must have 1 entry per quantization level");
  const float* scales_ptr = nullptr;

  if (has_scales) {
    TORCHAO_CHECK(scale_group_size >= 1, "scale_group_size must be >= 1");
    TORCHAO_CHECK(
        weight_scales->dtype() == torch::kFloat32,
        "weight_scales must be float32");
    TORCHAO_CHECK(weight_scales->dim() == 1, "weight_scales must be 1D");
    scales_ptr = weight_scales.value().const_data_ptr<float>();
  }

  const float* bias_ptr = nullptr;
  if (has_bias) {
    TORCHAO_CHECK(
        bias.value().dtype() == torch::kFloat32, "bias must be float32");
    TORCHAO_CHECK(bias.value().dim() == 1, "bias must be 1D");
    TORCHAO_CHECK(bias.value().size(0) == n, "expected 1 bias per row");
    bias_ptr = bias.value().const_data_ptr<float>();
  }

  TORCHAO_CHECK(
      !target.has_value(), "target is not currently supported in pack_weights");

  // Choose the packing format for this configuration, then the micro-kernel
  // that understands it.
  auto packed_weights_format =
      torchao::ops::groupwise_lowbit_weight_lut::select_packed_weights_format<
          weight_nbit>(
          target, scale_group_size, lut_group_size, has_scales, has_bias);

  auto packed_weights_header = packed_weights_format.to_packed_weights_header();
  auto uk = torchao::ops::groupwise_lowbit_weight_lut::select_ukernel_config<
      weight_nbit>(packed_weights_header);
  // Total buffer = serialized header + kernel-specific packed payload.
  auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
      uk.packed_weights_size(
          n,
          k,
          weight_nbit,
          scale_group_size,
          has_scales,
          has_bias,
          uk.nr,
          uk.kr,
          uk.sr);

  Tensor packed_weights = torch::empty(
      {static_cast<int64_t>(packed_weight_data_size)}, torch::kInt8);
  packed_weights_header.write(packed_weights.mutable_data_ptr<int8_t>());

  torchao::ops::groupwise_lowbit_weight_lut::pack_weights_operator(
      uk,
      // Payload is written right after the serialized header.
      packed_weights.mutable_data_ptr<int8_t>() +
          torchao::ops::PackedWeightsHeader::size(),
      n,
      k,
      scale_group_size,
      lut_group_size,
      weight_qval_idxs.const_data_ptr<uint8_t>(),
      scales_ptr,
      luts.const_data_ptr<float>(),
      bias_ptr);

  return packed_weights;
}
#endif // USE_ATEN
225+
226+
#ifdef USE_ATEN
// Meta-device implementation of the packing op: computes the packed buffer
// size exactly as the CPU packer does and returns an empty int8 meta tensor
// of that size (no data is written, no validation performed).
template <int weight_nbit>
Tensor pack_weights_with_lut_meta(
    const Tensor& weight_qval_idxs,
    const Tensor& luts,
    int64_t scale_group_size,
    int64_t lut_group_size,
    const std::optional<Tensor>& weight_scales,
    const std::optional<Tensor>& bias,
    const std::optional<std::string>& target) {
  const bool with_bias = bias.has_value();
  const bool with_scales = weight_scales.has_value();
  const int n = weight_qval_idxs.size(0);
  const int k = weight_qval_idxs.size(1);

  // Mirror the CPU path: derive the packing format and micro-kernel config
  // so the reported size matches what pack_weights_with_lut_cpu allocates.
  auto format =
      torchao::ops::groupwise_lowbit_weight_lut::select_packed_weights_format<
          weight_nbit>(
          target, scale_group_size, lut_group_size, with_scales, with_bias);
  auto header = format.to_packed_weights_header();
  auto ukernel = torchao::ops::groupwise_lowbit_weight_lut::
      select_ukernel_config<weight_nbit>(header);

  auto total_size = torchao::ops::PackedWeightsHeader::size() +
      ukernel.packed_weights_size(
          n,
          k,
          weight_nbit,
          scale_group_size,
          with_scales,
          with_bias,
          ukernel.nr,
          ukernel.kr,
          ukernel.sr);

  auto meta_options =
      torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8);
  return torch::empty({static_cast<int64_t>(total_size)}, meta_options);
}
#endif // USE_ATEN
265+
266+
} // namespace

0 commit comments

Comments
 (0)