Skip to content

Commit 0a21a7a

Browse files
[Inference] remove useless code and fix bug (#71488)
* remove useless code and fix a bug
1 parent 27a9b82 commit 0a21a7a

File tree

3 files changed

+138
-440
lines changed

3 files changed

+138
-440
lines changed

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h

Lines changed: 15 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -36,20 +36,24 @@ namespace phi {
3636

3737
static std::vector<CutlassTileConfig> get_candidate_tiles(
3838
const bool is_weight_only,
39-
const bool is_weight_only_encoder,
4039
const bool simt_configs_only,
4140
const int sm,
4241
const int group_size,
4342
const bool is_moe) {
4443
VLOG(3) << "get_candidate_tiles sm: " << sm;
45-
std::vector<CutlassTileConfig> simt_configs{
46-
CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
44+
if (simt_configs_only) {
45+
std::vector<CutlassTileConfig> simt_configs{
46+
CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
47+
return simt_configs;
48+
} else if (!is_weight_only) {
49+
std::vector<CutlassTileConfig> square_configs{
50+
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
51+
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
52+
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
53+
};
54+
return square_configs;
55+
}
4756

48-
std::vector<CutlassTileConfig> square_configs{
49-
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
50-
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
51-
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
52-
};
5357
std::vector<CutlassTileConfig> quant_B_configs_sm70{
5458
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
5559
CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64,
@@ -92,27 +96,17 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
9296
quant_B_configs = quant_B_configs_sm70;
9397
break;
9498
}
95-
const std::vector<CutlassTileConfig> allowed_quant_B_configs =
96-
quant_B_configs;
97-
const std::vector<CutlassTileConfig> allowed_configs =
98-
is_weight_only ? allowed_quant_B_configs : square_configs;
99-
return simt_configs_only ? simt_configs : allowed_configs;
99+
return quant_B_configs;
100100
}
101101

102102
static std::vector<CutlassGemmConfig> get_candidate_configs(
103103
const int sm,
104104
const int group_size,
105105
const bool is_weight_only,
106-
const bool is_weight_only_encoder,
107106
const bool simt_configs_only,
108107
const bool is_moe) {
109-
std::vector<CutlassTileConfig> tiles =
110-
get_candidate_tiles(is_weight_only,
111-
is_weight_only_encoder,
112-
simt_configs_only,
113-
sm,
114-
group_size,
115-
is_moe);
108+
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(
109+
is_weight_only, simt_configs_only, sm, group_size, is_moe);
116110

117111
std::vector<CutlassGemmConfig> candidate_configs;
118112
const int min_stages = 2;

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -596,9 +596,8 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag,
596596
cudaStream_t stream) {
597597
// VLOG(3)<<__PRETTY_FUNCTION__;
598598
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
599-
const bool is_weight_only_encoder = m >= 512 ? true : false;
600-
std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(
601-
sm_, group_size, is_weight_only, is_weight_only_encoder, false, false);
599+
std::vector<CutlassGemmConfig> candidate_configs =
600+
get_candidate_configs(sm_, group_size, is_weight_only, false, false);
602601

603602
// Standard GEMM, so 1 "expert". We use the same function for MoE and regular
604603
// FFN.

0 commit comments

Comments
 (0)