@@ -36,20 +36,24 @@ namespace phi {
36
36
37
37
static std::vector<CutlassTileConfig> get_candidate_tiles (
38
38
const bool is_weight_only,
39
- const bool is_weight_only_encoder,
40
39
const bool simt_configs_only,
41
40
const int sm,
42
41
const int group_size,
43
42
const bool is_moe) {
44
43
VLOG (3 ) << " get_candidate_tiles sm: " << sm;
45
- std::vector<CutlassTileConfig> simt_configs{
46
- CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
44
+ if (simt_configs_only) {
45
+ std::vector<CutlassTileConfig> simt_configs{
46
+ CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
47
+ return simt_configs;
48
+ } else if (!is_weight_only) {
49
+ std::vector<CutlassTileConfig> square_configs{
50
+ CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
51
+ CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
52
+ CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
53
+ };
54
+ return square_configs;
55
+ }
47
56
48
- std::vector<CutlassTileConfig> square_configs{
49
- CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
50
- CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
51
- CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
52
- };
53
57
std::vector<CutlassTileConfig> quant_B_configs_sm70{
54
58
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
55
59
CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64,
@@ -92,27 +96,17 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
92
96
quant_B_configs = quant_B_configs_sm70;
93
97
break ;
94
98
}
95
- const std::vector<CutlassTileConfig> allowed_quant_B_configs =
96
- quant_B_configs;
97
- const std::vector<CutlassTileConfig> allowed_configs =
98
- is_weight_only ? allowed_quant_B_configs : square_configs;
99
- return simt_configs_only ? simt_configs : allowed_configs;
99
+ return quant_B_configs;
100
100
}
101
101
102
102
static std::vector<CutlassGemmConfig> get_candidate_configs (
103
103
const int sm,
104
104
const int group_size,
105
105
const bool is_weight_only,
106
- const bool is_weight_only_encoder,
107
106
const bool simt_configs_only,
108
107
const bool is_moe) {
109
- std::vector<CutlassTileConfig> tiles =
110
- get_candidate_tiles (is_weight_only,
111
- is_weight_only_encoder,
112
- simt_configs_only,
113
- sm,
114
- group_size,
115
- is_moe);
108
+ std::vector<CutlassTileConfig> tiles = get_candidate_tiles (
109
+ is_weight_only, simt_configs_only, sm, group_size, is_moe);
116
110
117
111
std::vector<CutlassGemmConfig> candidate_configs;
118
112
const int min_stages = 2 ;
0 commit comments