Skip to content

Commit 0322d2e

Browse files
henrylhtsangfacebook-github-bot
authored andcommitted
making forward optimization work take 3 (#4864)
Summary: X-link: facebookresearch/FBGEMM#1887 Pull Request resolved: #4864 Reviewed By: Aya-ZIbra Differential Revision: D81809218 fbshipit-source-id: a6c26d5a8d3c10d3b49f2bf1a4cf017f92860400
1 parent d54fb85 commit 0322d2e

File tree

2 files changed

+180
-34
lines changed

2 files changed

+180
-34
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_fusion.hpp

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,26 @@ struct NoMask {
8484
return get_trip_count(blk_coord, tile_shape, problem_size);
8585
}
8686

87+
template<class BlkCoord, class TileShape, class ProblemSize>
88+
CUTLASS_DEVICE
89+
int get_n_block_start_unmask(
90+
BlkCoord const& blk_coord,
91+
TileShape const& tile_shape,
92+
ProblemSize const& problem_size) {
93+
94+
return 0;
95+
}
96+
97+
template<class BlkCoord, class TileShape, class ProblemSize>
98+
CUTLASS_DEVICE
99+
int get_n_block_stop_unmask(
100+
BlkCoord const& blk_coord,
101+
TileShape const& tile_shape,
102+
ProblemSize const& problem_size) {
103+
104+
return ceil_div(get<1>(problem_size), get<1>(tile_shape));
105+
}
106+
87107
template<class AccQK, class IndexQK, class ProblemSize>
88108
CUTLASS_DEVICE
89109
void apply_mask(
@@ -140,6 +160,26 @@ struct ResidualMask : NoMask {
140160
return get_trip_count(blk_coord, tile_shape, problem_size);
141161
}
142162

163+
template<class BlkCoord, class TileShape, class ProblemSize>
164+
CUTLASS_DEVICE
165+
int get_n_block_start_unmask(
166+
BlkCoord const& blk_coord,
167+
TileShape const& tile_shape,
168+
ProblemSize const& problem_size) {
169+
170+
return 0;
171+
}
172+
173+
template<class BlkCoord, class TileShape, class ProblemSize>
174+
CUTLASS_DEVICE
175+
int get_n_block_stop_unmask(
176+
BlkCoord const& blk_coord,
177+
TileShape const& tile_shape,
178+
ProblemSize const& problem_size) {
179+
180+
return get_unmasked_trip_count(blk_coord, tile_shape, problem_size);
181+
}
182+
143183
template<class AccQK, class IndexQK, class ProblemSize>
144184
CUTLASS_DEVICE
145185
void apply_mask(
@@ -293,6 +333,26 @@ struct CausalMask : NoMask {
293333
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
294334
}
295335

336+
template<class BlkCoord, class TileShape, class ProblemSize>
337+
CUTLASS_DEVICE
338+
int get_n_block_start_unmask(
339+
BlkCoord const& blk_coord,
340+
TileShape const& tile_shape,
341+
ProblemSize const& problem_size) {
342+
343+
return 0;
344+
}
345+
346+
template<class BlkCoord, class TileShape, class ProblemSize>
347+
CUTLASS_DEVICE
348+
int get_n_block_stop_unmask(
349+
BlkCoord const& blk_coord,
350+
TileShape const& tile_shape,
351+
ProblemSize const& problem_size) {
352+
353+
return get_unmasked_trip_count(blk_coord, tile_shape, problem_size);
354+
}
355+
296356
template<class AccQK, class IndexQK, class ProblemSize>
297357
CUTLASS_DEVICE
298358
void apply_mask(
@@ -456,8 +516,55 @@ struct LocalMask : NoMask {
456516
TileShape const& tile_shape,
457517
ProblemSize const& problem_size) {
458518

459-
// TODO: follow CausalMask to improve this
460-
return 0;
519+
const int n_block_start_unmask = get_n_block_start_unmask(blk_coord, tile_shape, problem_size);
520+
const int n_block_stop_unmask = get_n_block_stop_unmask(blk_coord, tile_shape, problem_size);
521+
522+
return n_block_stop_unmask - n_block_start_unmask;
523+
}
524+
525+
template<class BlkCoord, class TileShape, class ProblemSize>
526+
CUTLASS_DEVICE
527+
int get_n_block_start_unmask(
528+
BlkCoord const& blk_coord,
529+
TileShape const& tile_shape,
530+
ProblemSize const& problem_size) {
531+
// this does not guarantee to be smaller than n_block_stop_unmask
532+
533+
const int kBlockM = get<0>(tile_shape);
534+
const int kBlockN = get<1>(tile_shape);
535+
const int seq_len_k = get<1>(problem_size);
536+
537+
const int m_block = get<0>(blk_coord);
538+
const int offset_q = IsQBegin? 0 : get<1>(problem_size) - get<0>(problem_size);
539+
540+
const int m_idx_max = (m_block + 1) * kBlockM;
541+
542+
// -1 to make this inclusive
543+
const int n_idx_max_left = std::max(m_idx_max + offset_q - window_size_left - 1, 0);
544+
545+
return ceil_div(n_idx_max_left, kBlockN);
546+
}
547+
548+
template<class BlkCoord, class TileShape, class ProblemSize>
549+
CUTLASS_DEVICE
550+
int get_n_block_stop_unmask(
551+
BlkCoord const& blk_coord,
552+
TileShape const& tile_shape,
553+
ProblemSize const& problem_size) {
554+
// this does not guarantee to be larger than n_block_start_unmask
555+
556+
const int kBlockM = get<0>(tile_shape);
557+
const int kBlockN = get<1>(tile_shape);
558+
const int seq_len_k = get<1>(problem_size);
559+
560+
const int m_block = get<0>(blk_coord);
561+
const int offset_q = IsQBegin? 0 : get<1>(problem_size) - get<0>(problem_size);
562+
563+
const int m_idx_min = m_block * kBlockM;
564+
// +1 to make this exclusive
565+
const int n_idx_min_right = std::min(m_idx_min + offset_q + window_size_right + 1, seq_len_k);
566+
567+
return n_idx_min_right / kBlockN;
461568
}
462569

463570
template<class AccQK, class IndexQK, class ProblemSize>

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -729,11 +729,12 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
729729
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
730730
OrderBarrierSoftmax& order_s) {
731731

732-
int mask_tile_count = Mask(params.window_size_left, params.window_size_right).get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape);
733-
734-
auto min_max = Mask(params.window_size_left, params.window_size_right).get_n_block_min_max(blk_coord, TileShape{}, problem_shape);
732+
Mask mask(params.window_size_left, params.window_size_right);
733+
auto min_max = mask.get_n_block_min_max(blk_coord, TileShape{}, problem_shape);
735734
int n_block_min = get<0>(min_max);
736-
// int n_block_max = get<1>(min_max);
735+
const int n_block_max = get<1>(min_max);
736+
const int n_block_start_unmask = mask.get_n_block_start_unmask(blk_coord, TileShape{}, problem_shape);
737+
const int n_block_stop_unmask = mask.get_n_block_stop_unmask(blk_coord, TileShape{}, problem_shape);
737738

738739
ElementQK row_max = -INFINITY;
739740
ElementQK row_sum = 0;
@@ -747,35 +748,73 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
747748

748749
pipeline_c.producer_acquire(pipeline_c_producer_state);
749750

750-
CUTLASS_PRAGMA_NO_UNROLL
751-
for (; mask_tile_count > 0; mask_tile_count -= 1) {
752-
softmax_step<false /* need_apply_mask */>(
753-
row_max, row_sum, stage,
754-
(mask_tile_count == 1) &&
755-
(Mask(params.window_size_left, params.window_size_right).get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0),
756-
blk_coord, cS, params, problem_shape,
757-
pipeline_s, pipeline_s_consumer_state,
758-
pipeline_c, pipeline_c_producer_state,
759-
order_s
760-
);
761-
762-
cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
763-
}
764-
765-
// Masked iterations
766-
mask_tile_count = Mask(params.window_size_left, params.window_size_right).get_masked_trip_count(blk_coord, TileShape{}, problem_shape);
767-
768-
CUTLASS_PRAGMA_NO_UNROLL
769-
for (; mask_tile_count > 0; mask_tile_count -= 1) {
770-
softmax_step<true /* need_apply_mask */>(
771-
row_max, row_sum, stage, mask_tile_count == 1,
772-
blk_coord, cS, params, problem_shape,
773-
pipeline_s, pipeline_s_consumer_state,
774-
pipeline_c, pipeline_c_producer_state,
775-
order_s
776-
);
751+
// from observation, dispatch is better for the mask -> unmask -> mask pattern and when the number of tiles is small
752+
if constexpr (std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask>
753+
|| std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>) {
754+
auto dispatch_bool = [](bool b, auto fn) {
755+
if (b) {
756+
fn(cute::true_type{});
757+
}
758+
else {
759+
fn(cute::false_type{});
760+
}
761+
};
762+
763+
CUTLASS_PRAGMA_NO_UNROLL
764+
for (; n_block_min < n_block_max; n_block_min += 1) {
765+
// Apply mask only for tiles outside the attention window
766+
// for local mask, we don't guarantee n_block_start_unmask <= n_block_stop_unmask <= n_block_max
767+
bool need_apply_mask = warp_uniform(n_block_min < n_block_start_unmask || n_block_min >= n_block_stop_unmask);
768+
769+
dispatch_bool(need_apply_mask, [&](auto is_masked_tile) {
770+
if constexpr (decltype(is_masked_tile)::value) {
771+
softmax_step<true /* need_apply_mask */>(
772+
row_max, row_sum, stage, (n_block_min == n_block_max - 1),
773+
blk_coord, cS, params, problem_shape,
774+
pipeline_s, pipeline_s_consumer_state,
775+
pipeline_c, pipeline_c_producer_state,
776+
order_s
777+
);
778+
} else {
779+
softmax_step<false /* need_apply_mask */>(
780+
row_max, row_sum, stage, (n_block_min == n_block_max - 1),
781+
blk_coord, cS, params, problem_shape,
782+
pipeline_s, pipeline_s_consumer_state,
783+
pipeline_c, pipeline_c_producer_state,
784+
order_s
785+
);
786+
}
787+
});
788+
789+
cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
790+
}
791+
} else {
792+
CUTLASS_PRAGMA_NO_UNROLL
793+
for (; n_block_min < n_block_stop_unmask; n_block_min += 1) {
794+
softmax_step<false /* need_apply_mask */>(
795+
row_max, row_sum, stage,
796+
(n_block_min == n_block_max - 1),
797+
blk_coord, cS, params, problem_shape,
798+
pipeline_s, pipeline_s_consumer_state,
799+
pipeline_c, pipeline_c_producer_state,
800+
order_s
801+
);
802+
803+
cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
804+
}
777805

778-
cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
806+
CUTLASS_PRAGMA_NO_UNROLL
807+
for (; n_block_min < n_block_max; n_block_min += 1) {
808+
softmax_step<true /* need_apply_mask */>(
809+
row_max, row_sum, stage, n_block_min == n_block_max - 1,
810+
blk_coord, cS, params, problem_shape,
811+
pipeline_s, pipeline_s_consumer_state,
812+
pipeline_c, pipeline_c_producer_state,
813+
order_s
814+
);
815+
816+
cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
817+
}
779818
}
780819

781820
pipeline_c.producer_commit(pipeline_c_producer_state);

0 commit comments

Comments
 (0)