diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 8130be6..da66343 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -146,7 +146,6 @@ def dynamic_mask_attention_python( A: torch.Tensor, scaling: float, cache_position: torch.Tensor, - dout: torch.Tensor, keep_window_size=2048, is_causal=True, ): @@ -161,7 +160,6 @@ def dynamic_mask_attention_python( A: [num_kv_heads] scaling: Attention scaling factor cache_position: Cache position for causal masking - dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking @@ -201,7 +199,7 @@ def dynamic_mask_attention_python( attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] # Backward pass - attn_outputs.backward(dout) + attn_outputs.sum().backward() return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad @@ -214,7 +212,6 @@ def dynamic_mask_attention_cuda( A: torch.Tensor, scaling: float, cache_position: torch.Tensor, - dout: torch.Tensor, keep_window_size=2048, is_causal=True, ): @@ -229,7 +226,6 @@ def dynamic_mask_attention_cuda( A: [num_kv_heads] scaling: Attention scaling factor cache_position: Cache position for causal masking - dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking @@ -263,7 +259,7 @@ def dynamic_mask_attention_cuda( value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] # Call the flash_dmattn_func interface - attn_outputs, softmax_lse, S_dmask = flash_dmattn_func( + attn_outputs = flash_dmattn_func( query=query_states, # q: [batch, query_len, num_heads, head_dim] key=key_states, # k: [batch, key_len, num_kv_heads, head_dim] value=value_states, # v: [batch, key_len, num_kv_heads, head_dim] @@ -272,12 +268,12 @@ def dynamic_mask_attention_cuda( is_causal=is_causal, # causal masking scale=scaling, # scaling factor softcap=0.0, - deterministic=True, - return_attn_probs=True + deterministic=False, + return_attn_probs=False ) # Backward pass - attn_outputs.backward(dout) + attn_outputs.sum().backward() return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad @@ -290,7 +286,6 @@ def dynamic_mask_attention_triton( A: torch.Tensor, scaling: float, cache_position: torch.Tensor, - dout: torch.Tensor, keep_window_size=2048, is_causal=True, ): @@ -305,7 +300,6 @@ def dynamic_mask_attention_triton( A: [num_kv_heads] scaling: Attention scaling factor cache_position: Cache position for causal masking - dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking @@ -361,7 +355,7 @@ def dynamic_mask_attention_triton( ) # Backward pass - attn_outputs.backward(dout) + attn_outputs.sum().backward() return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad @@ -374,7 +368,6 @@ def dynamic_mask_attention_flex( A: torch.Tensor, scaling: float, cache_position: torch.Tensor, - dout: torch.Tensor, keep_window_size=2048, is_causal=True, ): @@ -389,7 +382,6 @@ def dynamic_mask_attention_flex( A: [num_kv_heads] scaling: Attention scaling factor cache_position: Cache position for causal masking - dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking @@ -436,7 +428,7 @@ def dynamic_mask_attention_flex( ) # Backward pass - attn_outputs.backward(dout) + attn_outputs.sum().backward() return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad @@ -552,76 +544,90 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) test_configs = [ # Head dim 32 - (1, 2, 1, 128, 128, 32, True), (1, 2, 1, 128, 128, 32, False), - (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 128, 128, 32, True), (1, 2, 1, 256, 256, 32, False), - (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 256, 256, 32, True), (1, 2, 1, 512, 512, 32, False), - (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 512, 512, 32, True), (1, 2, 1, 1024, 1024, 32, False), - (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 1024, 1024, 32, True), (1, 2, 1, 2048, 2048, 32, False), - (1, 2, 1, 4096, 4096, 32, True), # some INF in dbias, Idk why + (1, 2, 1, 2048, 2048, 32, True), (1, 2, 1, 4096, 4096, 32, False), + (1, 2, 1, 4096, 4096, 32, True), # Head dim 64 - (1, 2, 1, 128, 128, 64, True), (1, 2, 1, 128, 128, 64, False), - (1, 2, 1, 256, 256, 64, True), # some INF in dbias, Idk why + (1, 2, 1, 128, 128, 64, True), (1, 2, 1, 256, 256, 64, False), - (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 256, 256, 64, True), (1, 2, 1, 512, 512, 64, False), - (1, 2, 1, 1024, 1024, 64, True), # some INF in dbias, Idk why + (1, 2, 1, 512, 512, 64, True), (1, 2, 1, 1024, 1024, 64, False), - (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 1024, 1024, 64, True), (1, 2, 1, 2048, 2048, 64, False), - (1, 2, 1, 4096, 4096, 64, True), + (1, 2, 1, 2048, 2048, 64, True), (1, 2, 1, 4096, 4096, 64, False), + (1, 2, 1, 4096, 4096, 64, True), # Head dim 96 - (1, 2, 1, 128, 128, 96, True), (1, 2, 1, 128, 128, 96, False), - (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 128, 128, 96, True), (1, 2, 1, 256, 256, 96, False), - (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 256, 256, 96, True), (1, 2, 1, 512, 512, 96, False), - (1, 2, 1, 1024, 1024, 96, True), # some INF in dbias, Idk why + (1, 2, 1, 512, 512, 96, True), (1, 2, 1, 1024, 1024, 96, False), - (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 1024, 1024, 96, True), (1, 2, 1, 2048, 2048, 96, False), - (1, 2, 1, 4096, 4096, 96, True), + (1, 2, 1, 2048, 2048, 96, True), (1, 2, 1, 4096, 4096, 96, False), + (1, 2, 1, 4096, 4096, 96, True), # Head dim 128 - (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, False), - (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, False), - (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 256, 256, 128, True), (1, 2, 1, 512, 512, 128, False), - (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 512, 512, 128, True), (1, 2, 1, 1024, 1024, 128, False), - (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 1024, 1024, 128, True), (1, 2, 1, 2048, 2048, 128, False), - (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 2048, 2048, 128, True), (1, 2, 1, 4096, 4096, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + + # # Head dim 192 + # Not enough shared memory for head_dim=192 in bwd yet + # (1, 2, 1, 128, 128, 192, False), + # (1, 2, 1, 128, 128, 192, True), + # (1, 2, 1, 256, 256, 192, False), + # (1, 2, 1, 256, 256, 192, True), + # (1, 2, 1, 512, 512, 192, False), + # (1, 2, 1, 512, 512, 192, True), + # (1, 2, 1, 1024, 1024, 192, False), + # (1, 2, 1, 1024, 1024, 192, True), + # (1, 2, 1, 2048, 2048, 192, False), + # (1, 2, 1, 2048, 2048, 192, True), + # (1, 2, 1, 4096, 4096, 192, False), + # (1, 2, 1, 4096, 4096, 192, True), # Head dim 256 - # Because fwd uses splitkv branch, this branch does not support head_dim=256 for now - # For head_dim=256, besides the reason of splitkv branch, bwd itself does not support it, not enough shared memory - # (1, 2, 1, 128, 128, 256, True), + # Not enough shared memory for head_dim=256 in bwd yet # (1, 2, 1, 128, 128, 256, False), - # (1, 2, 1, 256, 256, 256, True), + # (1, 2, 1, 128, 128, 256, True), # (1, 2, 1, 256, 256, 256, False), - # (1, 2, 1, 512, 512, 256, True), + # (1, 2, 1, 256, 256, 256, True), # (1, 2, 1, 512, 512, 256, False), - # (1, 2, 1, 1024, 1024, 256, True), + # (1, 2, 1, 512, 512, 256, True), # (1, 2, 1, 1024, 1024, 256, False), - # (1, 2, 1, 2048, 2048, 256, True), + # (1, 2, 1, 1024, 1024, 256, True), # (1, 2, 1, 2048, 2048, 256, False), - # (1, 2, 1, 4096, 4096, 256, True), + # (1, 2, 1, 2048, 2048, 256, True), # (1, 2, 1, 4096, 4096, 256, False), + # (1, 2, 1, 4096, 4096, 256, True), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -673,13 +679,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): # Set scaling factor and keep window size scaling = head_dim ** -0.5 - keep_window_size = 64 - - # Create gradient for output - dout = torch.randn( - batch_size, query_len, num_heads, head_dim, - device=device, dtype=dtype - ) + keep_window_size = 1024 # Clone inputs for Python implementation query_python = query_states.clone().detach().requires_grad_(True) @@ -692,7 +692,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): start_time = time.time() attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python( query_python, key_python, value_python, dt_proj_python, A_python, - scaling, cache_position, dout.clone(), keep_window_size, is_causal + scaling, cache_position, keep_window_size, is_causal ) torch.cuda.synchronize() py_time = time.time() - start_time @@ -709,7 +709,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): start_time = time.time() attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda( query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda, - scaling, cache_position, dout.clone(), keep_window_size, is_causal + scaling, cache_position, keep_window_size, is_causal ) torch.cuda.synchronize() cuda_time = time.time() - start_time @@ -774,7 +774,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): if not is_close and max_dbias_diff > 1e-2: print(" ⚠️ Difference too large, stopping subsequent tests.") break - del query_states, key_states, value_states, dt_proj, A, cache_position, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda + del query_states, key_states, value_states, dt_proj, A, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda torch.cuda.empty_cache() gc.collect() torch.cuda.synchronize() diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index b1008b7..c08600b 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -177,13 +177,6 @@ def scaled_dot_product_attention_backward( value_states = value_states.contiguous() try: - # Create gradient for output - batch_size, num_heads, query_len, head_dim = query_states.shape - dout = torch.randn( - batch_size, query_len, num_heads, head_dim, - device=query_states.device, dtype=query_states.dtype - ) - # Forward pass - SDPA expects q, k, v in [batch, num_heads, seq_len, head_dim] format attn_outputs = F.scaled_dot_product_attention( query_states, # [batch, num_heads, query_len, head_dim] @@ -201,8 +194,8 @@ def scaled_dot_product_attention_backward( start_time = time.time() # Backward pass - attn_outputs.backward(dout) - + attn_outputs.sum().backward() + torch.cuda.synchronize() end_time = time.time() @@ -261,13 +254,6 @@ def dynamic_mask_attention_backward_cuda( value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] try: - # Create gradient for output - batch_size, query_len, num_heads, head_dim = query_states.shape - dout = torch.randn( - batch_size, query_len, num_heads, head_dim, - device=query_states.device, dtype=query_states.dtype - ) - # Call the flash_dmattn_func interface attn_outputs = flash_dmattn_func( query=query_states, # q: [batch, query_len, num_heads, head_dim] @@ -286,7 +272,7 @@ def dynamic_mask_attention_backward_cuda( start_time = time.time() # Backward pass - attn_outputs.backward(dout) + attn_outputs.sum().backward() torch.cuda.synchronize() end_time = time.time() @@ -356,13 +342,6 @@ def dynamic_mask_attention_backward_triton( value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - - # Create gradient for output - batch_size, query_len, num_heads, head_dim = query_states.shape - dout = torch.randn( - batch_size, query_len, num_heads, head_dim, - device=query_states.device, dtype=query_states.dtype - ) # Call the Triton implementation attn_outputs = triton_dmattn_func( @@ -379,7 +358,7 @@ def dynamic_mask_attention_backward_triton( start_time = time.time() # Backward pass - attn_outputs.backward(dout) + attn_outputs.sum().backward() torch.cuda.synchronize() end_time = time.time() @@ -445,13 +424,6 @@ def dynamic_mask_attention_backward_flex( # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format - - # Create gradient for output - batch_size, query_len, head_dim = query_states.shape[0], query_states.shape[2], query_states.shape[3] - dout = torch.randn( - batch_size, query_len, num_heads, head_dim, - device=query_states.device, dtype=query_states.dtype - ) # Call the Flex Attention implementation attn_outputs = flex_dmattn_func( @@ -468,7 +440,7 @@ def dynamic_mask_attention_backward_flex( start_time = time.time() # Backward pass - attn_outputs.backward(dout) + attn_outputs.sum().backward() torch.cuda.synchronize() end_time = time.time() diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 02cc262..97e80f9 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -515,76 +515,88 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) test_configs = [ # Head dim 32 - (1, 2, 1, 128, 128, 32, True), (1, 2, 1, 128, 128, 32, False), - (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 128, 128, 32, True), (1, 2, 1, 256, 256, 32, False), - (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 256, 256, 32, True), (1, 2, 1, 512, 512, 32, False), - (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 512, 512, 32, True), (1, 2, 1, 1024, 1024, 32, False), - (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 1024, 1024, 32, True), (1, 2, 1, 2048, 2048, 32, False), - (1, 2, 1, 4096, 4096, 32, True), + (1, 2, 1, 2048, 2048, 32, True), (1, 2, 1, 4096, 4096, 32, False), + (1, 2, 1, 4096, 4096, 32, True), # Head dim 64 - (1, 2, 1, 128, 128, 64, True), (1, 2, 1, 128, 128, 64, False), - (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 128, 128, 64, True), (1, 2, 1, 256, 256, 64, False), - (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 256, 256, 64, True), (1, 2, 1, 512, 512, 64, False), - (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 512, 512, 64, True), (1, 2, 1, 1024, 1024, 64, False), - (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 1024, 1024, 64, True), (1, 2, 1, 2048, 2048, 64, False), - (1, 2, 1, 4096, 4096, 64, True), + (1, 2, 1, 2048, 2048, 64, True), (1, 2, 1, 4096, 4096, 64, False), + (1, 2, 1, 4096, 4096, 64, True), # Head dim 96 - (1, 2, 1, 128, 128, 96, True), (1, 2, 1, 128, 128, 96, False), - (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 128, 128, 96, True), (1, 2, 1, 256, 256, 96, False), - (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 256, 256, 96, True), (1, 2, 1, 512, 512, 96, False), - (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 512, 512, 96, True), (1, 2, 1, 1024, 1024, 96, False), - (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 1024, 1024, 96, True), (1, 2, 1, 2048, 2048, 96, False), - (1, 2, 1, 4096, 4096, 96, True), + (1, 2, 1, 2048, 2048, 96, True), (1, 2, 1, 4096, 4096, 96, False), + (1, 2, 1, 4096, 4096, 96, True), # Head dim 128 - (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, False), - (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, False), - (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 256, 256, 128, True), (1, 2, 1, 512, 512, 128, False), - (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 512, 512, 128, True), (1, 2, 1, 1024, 1024, 128, False), - (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 1024, 1024, 128, True), (1, 2, 1, 2048, 2048, 128, False), - (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 2048, 2048, 128, True), (1, 2, 1, 4096, 4096, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + + # Head dim 192 + (1, 2, 1, 128, 128, 192, False), + (1, 2, 1, 128, 128, 192, True), + (1, 2, 1, 256, 256, 192, False), + (1, 2, 1, 256, 256, 192, True), + (1, 2, 1, 512, 512, 192, False), + (1, 2, 1, 512, 512, 192, True), + (1, 2, 1, 1024, 1024, 192, False), + (1, 2, 1, 1024, 1024, 192, True), + (1, 2, 1, 2048, 2048, 192, False), + (1, 2, 1, 2048, 2048, 192, True), + (1, 2, 1, 4096, 4096, 192, False), + (1, 2, 1, 4096, 4096, 192, True), - # Not support head_dim = 256 in sm89 yet - # Because fwd uses splitkv branch by default, and shared memory is not enough for sm89 # Head dim 256 - # (1, 2, 1, 128, 128, 256, True), - # (1, 2, 1, 128, 128, 256, False), - # (1, 2, 1, 256, 256, 256, True), - # (1, 2, 1, 256, 256, 256, False), - # (1, 2, 1, 512, 512, 256, True), - # (1, 2, 1, 512, 512, 256, False), - # (1, 2, 1, 1024, 1024, 256, True), - # (1, 2, 1, 1024, 1024, 256, False), - # (1, 2, 1, 2048, 2048, 256, True), - # (1, 2, 1, 2048, 2048, 256, False), - # (1, 2, 1, 4096, 4096, 256, True), - # (1, 2, 1, 4096, 4096, 256, False), + (1, 2, 1, 128, 128, 256, False), + (1, 2, 1, 128, 128, 256, True), + (1, 2, 1, 256, 256, 256, False), + (1, 2, 1, 256, 256, 256, True), + (1, 2, 1, 512, 512, 256, False), + (1, 2, 1, 512, 512, 256, True), + (1, 2, 1, 1024, 1024, 256, False), + (1, 2, 1, 1024, 1024, 256, True), + (1, 2, 1, 2048, 2048, 256, False), + (1, 2, 1, 2048, 2048, 256, True), + (1, 2, 1, 4096, 4096, 256, False), + (1, 2, 1, 4096, 4096, 256, True), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -635,7 +647,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): # Set scaling factor and keep window size scaling = head_dim ** -0.5 - keep_window_size = 64 + keep_window_size = 1024 # Run Python implementation start_time = time.time() diff --git a/csrc/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp similarity index 100% rename from csrc/flash_api.cpp rename to csrc/flash_dmattn/flash_api.cpp diff --git a/csrc/src/block_info.h b/csrc/flash_dmattn/src/block_info.h similarity index 100% rename from csrc/src/block_info.h rename to csrc/flash_dmattn/src/block_info.h diff --git a/csrc/src/flash.h b/csrc/flash_dmattn/src/flash.h similarity index 100% rename from csrc/src/flash.h rename to csrc/flash_dmattn/src/flash.h diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h similarity index 100% rename from csrc/src/flash_bwd_kernel.h rename to csrc/flash_dmattn/src/flash_bwd_kernel.h diff --git a/csrc/src/flash_bwd_launch_template.h b/csrc/flash_dmattn/src/flash_bwd_launch_template.h similarity index 96% rename from csrc/src/flash_bwd_launch_template.h rename to csrc/flash_dmattn/src/flash_bwd_launch_template.h index b16a957..06a46fd 100644 --- a/csrc/src/flash_bwd_launch_template.h +++ b/csrc/flash_dmattn/src/flash_bwd_launch_template.h @@ -141,8 +141,8 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { // 104KB, 1 CTAs in A100, 2 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 96KB, 2 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + // 96KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } } @@ -166,8 +166,8 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { // 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 72KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + // 88KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times } @@ -187,8 +187,8 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { // 116KB, 1 CTAs in A100, 1 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 92KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + // 76KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } } @@ -207,8 +207,8 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { // 144KB, 1 CTAs in A100, 1 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 88KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + // 80KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } } diff --git a/csrc/src/flash_bwd_preprocess_kernel.h b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h similarity index 100% rename from csrc/src/flash_bwd_preprocess_kernel.h rename to csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h similarity index 99% rename from csrc/src/flash_fwd_kernel.h rename to csrc/flash_dmattn/src/flash_fwd_kernel.h index 77a5a19..c98e3b1 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -395,8 +395,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; - // Init dynamic mask processor - FLASH_NAMESPACE::Mask mask( + // Init mask processor + FLASH_NAMESPACE::Mask mask( binfo.actual_seqlen_k, binfo.actual_seqlen_q ); @@ -1044,8 +1044,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; - // Init dynamic mask processor - FLASH_NAMESPACE::Mask mask( + // Init mask processor + FLASH_NAMESPACE::Mask mask( binfo.actual_seqlen_k, binfo.actual_seqlen_q ); diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/flash_dmattn/src/flash_fwd_launch_template.h similarity index 100% rename from csrc/src/flash_fwd_launch_template.h rename to csrc/flash_dmattn/src/flash_fwd_launch_template.h diff --git a/csrc/src/generate_kernels.py b/csrc/flash_dmattn/src/generate_kernels.py similarity index 100% rename from csrc/src/generate_kernels.py rename to csrc/flash_dmattn/src/generate_kernels.py diff --git a/csrc/src/hardware_info.h b/csrc/flash_dmattn/src/hardware_info.h similarity index 100% rename from csrc/src/hardware_info.h rename to csrc/flash_dmattn/src/hardware_info.h diff --git a/csrc/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu diff --git a/csrc/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu rename to csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu diff --git a/csrc/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h similarity index 100% rename from csrc/src/kernel_traits.h rename to csrc/flash_dmattn/src/kernel_traits.h diff --git a/csrc/src/mask.h b/csrc/flash_dmattn/src/mask.h similarity index 99% rename from csrc/src/mask.h rename to csrc/flash_dmattn/src/mask.h index f24109a..ca4f43b 100644 --- a/csrc/src/mask.h +++ b/csrc/flash_dmattn/src/mask.h @@ -54,7 +54,6 @@ __forceinline__ __device__ void apply_mask( } } -template struct Mask { const int max_seqlen_k, max_seqlen_q; diff --git a/csrc/src/namespace_config.h b/csrc/flash_dmattn/src/namespace_config.h similarity index 100% rename from csrc/src/namespace_config.h rename to csrc/flash_dmattn/src/namespace_config.h diff --git a/csrc/src/softmax.h b/csrc/flash_dmattn/src/softmax.h similarity index 100% rename from csrc/src/softmax.h rename to csrc/flash_dmattn/src/softmax.h diff --git a/csrc/src/static_switch.h b/csrc/flash_dmattn/src/static_switch.h similarity index 100% rename from csrc/src/static_switch.h rename to csrc/flash_dmattn/src/static_switch.h diff --git a/csrc/src/utils.h b/csrc/flash_dmattn/src/utils.h similarity index 100% rename from csrc/src/utils.h rename to csrc/flash_dmattn/src/utils.h diff --git a/setup.py b/setup.py index 5173780..32e0936 100644 --- a/setup.py +++ b/setup.py @@ -207,93 +207,93 @@ def append_nvcc_threads(nvcc_extra_args): CUDAExtension( name="flash_dmattn_cuda", sources=[ - "csrc/flash_api.cpp", + "csrc/flash_dmattn/flash_api.cpp", # Forward kernels - regular - "csrc/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu", # Forward kernels - causal - "csrc/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu", # Forward kernels - split - "csrc/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu", # Forward kernels - split causal - "csrc/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu", # Backward kernels - regular - "csrc/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu", # Backward kernels - causal - "csrc/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu", - "csrc/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu", + "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu", ], extra_compile_args={ "cxx": compiler_c17_flag, "nvcc": append_nvcc_threads(nvcc_flags + cc_flag), }, include_dirs=[ - Path(this_dir) / "csrc", - Path(this_dir) / "csrc" / "src", + Path(this_dir) / "csrc" / "flash_dmattn", + Path(this_dir) / "csrc" / "flash_dmattn" / "src", Path(this_dir) / "csrc" / "cutlass" / "include", ], )