Skip to content

Commit b9af5ac

Browse files
committed
Removes SMEM-based split-kv restriction
Drops the device shared-memory (SMEM) query and the special case that forced `num_splits = 1` when the head dimension is >= 128 and the GPU's opt-in shared memory per block is below 112 KB (e.g. sm89). This simplifies the forward path and defers split selection to the caller, reducing runtime branching.
1 parent 9cc0ca6 commit b9af5ac

File tree

1 file changed

+0
-13
lines changed

1 file changed

+0
-13
lines changed

csrc/flash_dmattn/flash_api.cpp

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -249,24 +249,11 @@ void set_params_dgrad(
249249
}
250250

251251
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
252-
int device;
253-
cudaGetDevice(&device);
254-
int max_smem_per_block;
255-
cudaError status_ = cudaDeviceGetAttribute(
256-
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
257-
);
258-
if (status_ != cudaSuccess) {
259-
C10_CUDA_CHECK(status_);
260-
}
261-
262252
FP16_SWITCH(!params.is_bf16, [&] {
263253
HEADDIM_SWITCH(params.d, [&] {
264254
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
265255
BOOL_SWITCH(params.has_mask, Has_mask, [&] {
266256
BOOL_SWITCH(params.has_bias, Has_bias, [&] {
267-
// splitkv kernel is not supported for head_dim >= 128 in sm89 due to smem limits
268-
bool splitkv_forbidden = (kHeadDim >= 128) && (max_smem_per_block < 112 * 1024);
269-
params.num_splits = splitkv_forbidden ? 1 : params.num_splits;
270257
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
271258
run_mha_fwd_<elem_type, kHeadDim, Is_causal, Has_mask, Has_bias>(params, stream);
272259
} else {

0 commit comments

Comments
 (0)