
Commit c760f8f

cthi authored and facebook-github-bot committed
Fix tuning cache for f8f8bf16_rowwise_grouped on SM100 (#4843)
Summary:
Pull Request resolved: #4843
X-link: facebookresearch/FBGEMM#1871

Previously it would accidentally run SM90 kernels.

Reviewed By: q10

Differential Revision: D82022651

fbshipit-source-id: ee739499faf61f73e5e9fbdb9d244cacb50c92e0
1 parent 210347c commit c760f8f
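For context: `get_kernel_via_tuning` builds its candidate set from a kernel manifest, but it always used the SM90 manifest, so on SM100 devices the autotuner benchmarked and cached SM90 kernels. The fix below selects the manifest by device architecture with an immediately invoked lambda. Here is a minimal, compilable sketch of that dispatch pattern; every name in it (`DummyKernel`, `detect_arch`, the two getters) is an illustrative stand-in, not an FBGEMM API:

// Minimal sketch of the dispatch pattern in this fix: pick the kernel
// manifest for the detected architecture once, via an immediately
// invoked lambda. All names below are illustrative stand-ins.
#include <iostream>
#include <string>
#include <unordered_map>

using DummyKernel = int (*)(int);

int sm90_impl(int x) { return x + 90; }
int sm100_impl(int x) { return x + 100; }

const std::unordered_map<std::string, DummyKernel>& get_sm90_kernels() {
  static const std::unordered_map<std::string, DummyKernel> kernels = {
      {"kernel_a", sm90_impl}};
  return kernels;
}

const std::unordered_map<std::string, DummyKernel>& get_sm100_kernels() {
  static const std::unordered_map<std::string, DummyKernel> kernels = {
      {"kernel_a", sm100_impl}};
  return kernels;
}

// Stand-in for getDeviceArch(); pretend we are on an SM100 device.
int detect_arch() { return 10; }

int main() {
  // Same shape as the fix: branch on arch once, then hand the chosen
  // manifest to the tuning cache.
  const auto& kernels = []() -> const auto& {
    if (detect_arch() == 9) {
      return get_sm90_kernels();
    } else {
      return get_sm100_kernels();
    }
  }();
  std::cout << kernels.at("kernel_a")(1) << '\n';  // prints 101
}

One subtlety: in the actual diff the lambda has no trailing return type, so `auto` deduction returns the map by value (a copy) that is then lifetime-extended by the `const auto&` binding; the `-> const auto&` trailing return type used in the sketch would instead return the existing static map by reference.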

File tree

2 files changed: +37 -1 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 10 additions & 1 deletion
@@ -202,7 +202,16 @@ Kernel_f8f8bf16_rowwise_grouped<InputType> get_kernel_via_tuning(
   const std::string shape_key = std::to_string(total_M) + "_" +
       std::to_string(max_N) + "_" + std::to_string(max_K) + "_" +
       std::to_string(G);
-  const auto& kernels = get_f8f8bf16_rowwise_grouped_kernels<InputType>();
+
+  const auto& kernels = []() {
+    const int arch = getDeviceArch();
+    if (arch == 9) {
+      return get_f8f8bf16_rowwise_grouped_kernels<InputType>();
+    } else {
+      return get_f8f8bf16_rowwise_grouped_kernels_sm100<InputType>();
+    }
+  }();
+
   auto kernel = cache.findBestKernelMaybeAutotune(
       shape_key,
       kernels,
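The chosen manifest is handed to `cache.findBestKernelMaybeAutotune`, which presumably benchmarks each candidate for a given shape key and caches the winner; that is why seeding it with the wrong manifest silently pinned SM90 kernels on SM100. A hypothetical, much-simplified sketch of such an autotune cache follows; this is not FBGEMM's actual implementation (which would time kernels with CUDA events, synchronize the device, and so on):

// Hypothetical simplified autotune cache, for illustration only.
#include <chrono>
#include <functional>
#include <limits>
#include <string>
#include <unordered_map>

template <typename Kernel>
class TuningCache {
 public:
  // Benchmark each candidate once for this shape key and cache the fastest.
  Kernel findBestKernel(
      const std::string& shape_key,
      const std::unordered_map<std::string, Kernel>& kernels,
      const std::function<void(Kernel)>& run) {
    auto it = cache_.find(shape_key);
    if (it != cache_.end()) {
      return it->second;  // already tuned for this shape
    }
    Kernel best{};
    double best_time = std::numeric_limits<double>::max();
    for (const auto& [name, kernel] : kernels) {
      auto start = std::chrono::steady_clock::now();
      run(kernel);  // launch candidate on the target shape
      std::chrono::duration<double> elapsed =
          std::chrono::steady_clock::now() - start;
      if (elapsed.count() < best_time) {
        best_time = elapsed.count();
        best = kernel;
      }
    }
    cache_.emplace(shape_key, best);
    return best;
  }

 private:
  std::unordered_map<std::string, Kernel> cache_;
};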

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped_sm100/f8f8bf16_rowwise_grouped_manifest.cuh

Lines changed: 27 additions & 0 deletions
@@ -154,4 +154,31 @@ at::Tensor f8f8bf16_rowwise_grouped_256_256_128_2_1_1_10_f(
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes);
 
+template <typename InputType>
+const std::
+    unordered_map<std::string, Kernel_f8f8bf16_rowwise_grouped<InputType>>&
+get_f8f8bf16_rowwise_grouped_kernels_sm100() {
+  static const std::
+      unordered_map<std::string, Kernel_f8f8bf16_rowwise_grouped<InputType>>
+          kernels = {
+              {"f8f8bf16_rowwise_grouped_128_32_128_2_1_1_10_f",
+               f8f8bf16_rowwise_grouped_128_32_128_2_1_1_10_f},
+              {"f8f8bf16_rowwise_grouped_128_64_128_2_1_1_10_f",
+               f8f8bf16_rowwise_grouped_128_64_128_2_1_1_10_f},
+              {"f8f8bf16_rowwise_grouped_128_128_128_2_1_1_10_f",
+               f8f8bf16_rowwise_grouped_128_128_128_2_1_1_10_f},
+              {"f8f8bf16_rowwise_grouped_128_256_128_2_1_1_10_f",
+               f8f8bf16_rowwise_grouped_128_256_128_2_1_1_10_f},
+              {"f8f8bf16_rowwise_grouped_256_32_128_2_1_1_10_f",
+               f8f8bf16_rowwise_grouped_256_32_128_2_1_1_10_f},
+              {"f8f8bf16_rowwise_grouped_256_64_128_2_1_1_10_f",
+               f8f8bf16_rowwise_grouped_256_64_128_2_1_1_10_f},
+              {"f8f8bf16_rowwise_grouped_256_128_128_2_1_1_10_f",
+               f8f8bf16_rowwise_grouped_256_128_128_2_1_1_10_f},
+              {"f8f8bf16_rowwise_grouped_256_256_128_2_1_1_10_f",
+               f8f8bf16_rowwise_grouped_256_256_128_2_1_1_10_f},
+          };
+  return kernels;
+}
+
 } // namespace fbgemm_gpu
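The new `get_f8f8bf16_rowwise_grouped_kernels_sm100` mirrors the existing SM90 manifest getter: a function-local static map from kernel name to function pointer, constructed once per `InputType` instantiation on first call and thread-safe under C++11 "magic statics". A tiny stand-alone illustration of that initialize-once behavior, with made-up names:

// Function-local static ("Meyers singleton"): the map is constructed
// exactly once, on the first call, and every later call returns the
// same instance. All names here are made up for illustration.
#include <iostream>
#include <string>
#include <unordered_map>

const std::unordered_map<std::string, int>& get_manifest() {
  static const std::unordered_map<std::string, int> manifest = [] {
    std::cout << "building manifest\n";  // printed exactly once
    return std::unordered_map<std::string, int>{{"kernel_a", 1}};
  }();
  return manifest;
}

int main() {
  get_manifest();                                      // prints "building manifest"
  get_manifest();                                      // no output; same map
  std::cout << get_manifest().at("kernel_a") << '\n';  // prints 1
}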
