Skip to content

Commit 2e2ce0b

Browse files
authored
Add kernel selector
Differential Revision: D77616329 Pull Request resolved: #2534
1 parent 378e179 commit 2e2ce0b

File tree

1 file changed

+240
-0
lines changed

1 file changed

+240
-0
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
#include <cpuinfo.h>
9+
#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h>
10+
#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h>
11+
#include <optional>
12+
#include <stdexcept>
13+
#include <unordered_map>
14+
15+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
16+
#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
17+
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h>
18+
#endif // TORCHAO_ENABLE_ARM_NEON_DOT
19+
#endif // TORCHAO_BUILD_CPU_AARCH64
20+
21+
namespace torchao::ops::groupwise_lowbit_weight_lut {
22+
23+
/**
24+
* @brief A thread-unsafe registration table for kernel configurations.
25+
*
26+
* This table maps a combination of a weight format (header) and a CPU
27+
* microarchitecture to a specific UKernelConfig.
28+
*/
29+
struct UKernelConfigRegistrationTable {
30+
private:
31+
using Key = std::pair<torchao::ops::PackedWeightsHeader, cpuinfo_uarch>;
32+
struct KeyHasher {
33+
std::size_t operator()(const Key& k) const {
34+
return std::hash<torchao::ops::PackedWeightsHeader>()(k.first) ^
35+
std::hash<int>()(static_cast<int>(k.second));
36+
}
37+
};
38+
std::unordered_map<Key, UKernelConfig, KeyHasher> registration_table_;
39+
inline Key make_key(
40+
torchao::ops::PackedWeightsHeader header,
41+
cpuinfo_uarch uarch) const {
42+
return std::make_pair(header, uarch);
43+
}
44+
45+
public:
46+
// resgist a kernel config for a given format and uarch.
47+
void register_ukernel_config(
48+
PackedWeightsFormat format,
49+
cpuinfo_uarch uarch,
50+
UKernelConfig config) {
51+
auto header = format.to_packed_weights_header();
52+
auto key = make_key(header, uarch);
53+
if (registration_table_.find(key) != registration_table_.end()) {
54+
throw std::runtime_error(
55+
"UKernelConfig is already registered for this format");
56+
}
57+
config.validate();
58+
registration_table_[key] = config;
59+
}
60+
// get the kernel config for a given format and uarch.
61+
std::optional<UKernelConfig> get_ukernel_config(
62+
torchao::ops::PackedWeightsHeader header,
63+
cpuinfo_uarch uarch) const {
64+
auto key = make_key(header, uarch);
65+
auto it = registration_table_.find(key);
66+
if (it == registration_table_.end()) {
67+
return std::nullopt;
68+
}
69+
return it->second;
70+
}
71+
};
72+
73+
/**
 * @brief Logs the packed-weights format a ukernel config is registered for.
 *
 * Emits a LOG(INFO) message only when USE_ATEN is defined; otherwise this is
 * a no-op.
 *
 * Declared inline because this header-defined function would otherwise be
 * emitted once per including translation unit and fail at link time.
 *
 * @param format The packed-weights format being registered.
 * @param description Human-readable name of the kernel being registered.
 */
inline void log_registration(
    PackedWeightsFormat format,
    std::string description) {
  // Logging is only supported in ATen mode
#ifdef USE_ATEN
  LOG(INFO) << "Registering ukernel config for groupwise_lowbit_weight_lut"
            << std::endl
            << "\tDescription: " << description << std::endl
            << "\tformat.type=" << static_cast<int>(format.type) << std::endl
            << "\tformat.weight_nbit=" << format.weight_nbit << std::endl
            << "\tformat.has_bias=" << format.has_bias << std::endl
            << "\tformat.has_scales=" << format.has_scales << std::endl
            << "\tformat.lut_group_size=" << format.lut_group_size << std::endl
            // The line break after scale_group_size was missing, which glued
            // this field and format.nr together on one line.
            << "\tformat.scale_group_size=" << format.scale_group_size
            << std::endl
            << "\tformat.nr=" << format.nr << std::endl
            << "\tformat.kr=" << format.kr << std::endl
            << "\tformat.sr=" << format.sr << std::endl
            << std::endl;
#else
  // Nothing is logged outside ATen mode; silence unused-parameter warnings.
  (void)format;
  (void)description;
#endif // USE_ATEN
}
91+
92+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
93+
/**
94+
* @brief Registers all available AArch64 kernels for a given format.
95+
*
96+
* @tparam weight_nbit The bit-width of the weights.
97+
* @tparam has_scales Whether the packed buffer contains scale factors.
98+
* @param table The registration table to add the kernel config to.
99+
* @param format The format header describing the weights.
100+
* @param uarch The target CPU microarchitecture.
101+
*/
102+
template <int weight_nbit>
103+
void register_ukernel_config(
104+
UKernelConfigRegistrationTable& table,
105+
PackedWeightsFormat format,
106+
cpuinfo_uarch uarch) {
107+
if (!cpuinfo_initialize()) {
108+
throw std::runtime_error("Failed to initialize cpuinfo!");
109+
}
110+
if (!cpuinfo_has_arm_v8()) {
111+
// This CPU doesn't support the kernel, so do nothing.
112+
return;
113+
}
114+
115+
check_format(
116+
format,
117+
torchao::ops::PackedWeightsType::groupwise_lowbit_weight_lut,
118+
weight_nbit);
119+
int preferred_alignment = 16;
120+
121+
namespace kernel_api =
122+
torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut;
123+
124+
using kernel_fn_ptr_t =
125+
decltype(&kernel_api::kernel_lowbit_1x4x32_f32<weight_nbit, true>);
126+
kernel_fn_ptr_t kernel_dispatcher;
127+
128+
if (format.has_scales) {
129+
kernel_dispatcher =
130+
&kernel_api::kernel_lowbit_1x4x32_f32<weight_nbit, /*has_scales=*/true>;
131+
} else {
132+
kernel_dispatcher =
133+
&kernel_api::
134+
kernel_lowbit_1x4x32_f32<weight_nbit, /*has_scales=*/false>;
135+
}
136+
if (format.nr == 4 && format.kr == 32 && format.sr == 8) {
137+
log_registration(format, "lut: kernel_lowbit_1x4x32_f32");
138+
constexpr int nr = 4;
139+
constexpr int kr = 32;
140+
constexpr int sr = 8;
141+
constexpr int mr = 1;
142+
constexpr int m_step = 1;
143+
constexpr int n_step = 4;
144+
145+
auto uk = UKernelConfig::make(
146+
/*preferred_alignment=*/preferred_alignment,
147+
/*n_step=*/n_step,
148+
/*nr=*/format.nr,
149+
/*kr=*/format.kr,
150+
/*sr=*/format.sr,
151+
/*weight_nbit=*/format.weight_nbit,
152+
/*has_scales=*/format.has_scales,
153+
/*has_bias=*/format.has_bias,
154+
/*packed_weights_size_fn_type=*/
155+
&kernel_api::packed_weights_size<weight_nbit, nr, kr, sr>,
156+
/*pack_weights_fn_type=*/
157+
&kernel_api::
158+
pack_weights_for_groupwise_lut_kernel<weight_nbit, nr, kr, sr>,
159+
/*configs=*/{});
160+
161+
uk.configs[0] = UKernelConfig::group_config_type(
162+
{m_step,
163+
mr,
164+
&kernel_api::packed_activations_size,
165+
&kernel_api::packed_activations_offset,
166+
&kernel_api::pack_activations<mr, kr, sr>,
167+
kernel_dispatcher});
168+
169+
// Resgister the kernel config.
170+
table.register_ukernel_config(format, uarch, std::move(uk));
171+
}
172+
}
173+
#endif // TORCHAO_BUILD_CPU_AARCH64
174+
175+
/**
176+
* @brief Selects the best UKernelConfig for the given format header.
177+
*
178+
* This function is the main entry point for the op. It manages a static
179+
* registration table and, if a kernel is not already registered for the
180+
* current CPU, it will perform the registration.
181+
*
182+
* @tparam weight_nbit The bit-width of the weights.
183+
* @param header A header describing the packed weight format.
184+
* @return The appropriate UKernelConfig for the current environment.
185+
*/
186+
template <int weight_nbit>
187+
UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) {
188+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
189+
// Static table ensures we only register kernels once per session.
190+
static UKernelConfigRegistrationTable table;
191+
192+
if (!cpuinfo_initialize()) {
193+
throw std::runtime_error("Failed to initialize cpuinfo!");
194+
}
195+
196+
auto uarch = cpuinfo_uarch_unknown;
197+
198+
auto ukernel = table.get_ukernel_config(header, uarch);
199+
if (ukernel.has_value()) {
200+
return ukernel.value();
201+
}
202+
203+
// Create a new format object from the header.
204+
auto format = PackedWeightsFormat::from_packed_weights_header(header);
205+
206+
register_ukernel_config<weight_nbit>(table, format, uarch);
207+
208+
ukernel = table.get_ukernel_config(header, uarch);
209+
assert(ukernel.has_value() && "Kernel registration failed for the current CPU microarchitecture.");
210+
return ukernel.value();
211+
#else
212+
throw std::runtime_error(
213+
"select_ukernel_config for groupwise_lowbit_weight_lut is only supported "
214+
"when TORCHAO_BUILD_CPU_AARCH64 is defined.");
215+
#endif
216+
}
217+
218+
template <int weight_nbit>
219+
PackedWeightsFormat select_packed_weights_format(
220+
std::optional<std::string> target,
221+
int scale_group_size,
222+
int lut_group_size,
223+
bool has_scales,
224+
bool has_bias) {
225+
if (!target) {
226+
return PackedWeightsFormat(
227+
torchao::ops::PackedWeightsType::groupwise_lowbit_weight_lut,
228+
weight_nbit,
229+
scale_group_size,
230+
lut_group_size,
231+
has_scales,
232+
has_bias,
233+
/*nr*/ 4,
234+
/*kr*/ 32,
235+
/*sr*/ 8);
236+
}
237+
throw std::runtime_error("No packed_weights_format was selected");
238+
}
239+
240+
} // namespace torchao::ops::groupwise_lowbit_weight_lut

0 commit comments

Comments
 (0)