diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 0000000..dfb007e
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,69 @@
+name: Bug report
+description: Create a report to help us improve Flash-DMA
+title: "[BUG REPORT] "
+labels:
+  - bug
+assignees:
+  - LoserCheems
+  - Evanwu1125
+  - SNHuan
+  - Thanksyy
+  - ftgreat
+  - zacliu2023
+  - juliohsu
+  - wubingheng111
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thanks for taking the time to report an issue. Please fill out the details below so we can reproduce and fix the problem quickly.
+  - type: textarea
+    id: bug-description
+    attributes:
+      label: Describe the bug
+      description: Provide a concise description of the incorrect behavior.
+      placeholder: Unexpected error when calling flash_dmattn(...)
+    validations:
+      required: true
+  - type: textarea
+    id: reproduction
+    attributes:
+      label: Steps to reproduce
+      description: Share the minimal steps or code necessary for us to see the failure.
+      placeholder: |
+        1. Import flash_dmattn
+        2. Run the snippet below
+        3. Observe the error
+      render: python
+    validations:
+      required: true
+  - type: textarea
+    id: expected-behavior
+    attributes:
+      label: Expected behavior
+      description: Tell us what you expected to happen instead.
+      placeholder: The kernel should return valid attention output without raising an exception.
+    validations:
+      required: true
+  - type: textarea
+    id: environment
+    attributes:
+      label: Environment information
+      description: Run the following command and paste the full output.
+      placeholder: |
+        python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.version.cuda}'); print(f'GPU: {torch.cuda.get_device_name() if torch.cuda.is_available() else \"None\"}')"
+      render: shell
+    validations:
+      required: true
+  - type: textarea
+    id: additional-context
+    attributes:
+      label: Additional context
+      description: Include sequence lengths, batch sizes, or any other details that might help us debug.
+      placeholder: Tested with seq_len=8192, batch=2, head_dim=128...
+  - type: textarea
+    id: traceback
+    attributes:
+      label: Error traceback
+      description: Paste the full traceback if available.
+      render: text
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
new file mode 100644
index 0000000..46a7d39
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,64 @@
+name: Feature request
+description: Suggest an idea for Flash-DMA
+title: "[FEATURE REQUEST] "
+labels:
+  - feature
+assignees:
+  - LoserCheems
+  - Evanwu1125
+  - SNHuan
+  - Thanksyy
+  - ftgreat
+  - zacliu2023
+  - juliohsu
+  - wubingheng111
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Help us understand the feature you are proposing and why it matters for Flash-DMA workflows.
+  - type: textarea
+    id: problem
+    attributes:
+      label: Problem statement
+      description: Explain the problem or limitation that motivates this feature request.
+      placeholder: I am limited by...
+    validations:
+      required: true
+  - type: textarea
+    id: proposed-solution
+    attributes:
+      label: Proposed solution
+      description: Describe the feature or behavior you would like to see.
+      placeholder: Introduce a kernel path that...
+    validations:
+      required: true
+  - type: textarea
+    id: alternatives
+    attributes:
+      label: Alternatives considered
+      description: List any other approaches you have evaluated and why they are insufficient.
+      placeholder: I tried using...
+  - type: textarea
+    id: implementation
+    attributes:
+      label: Implementation details
+      description: Call out potential CUDA/Python changes, performance implications, or compatibility considerations.
+      placeholder: Requires updates to flash_dmattn_interface and CUDA op...
+  - type: textarea
+    id: use-case
+    attributes:
+      label: Use case
+      description: Describe the workloads or scenarios that would benefit from this feature.
+      placeholder: Long-context code completion with...
+  - type: textarea
+    id: references
+    attributes:
+      label: Related work
+      description: Share links to papers, repositories, or prior art that inspired this request.
+      placeholder: Paper link or repository URL
+  - type: textarea
+    id: additional-context
+    attributes:
+      label: Additional context
+      description: Add any extra information or screenshots that may help us understand the request.
diff --git a/.github/ISSUE_TEMPLATE/performance_issue.yml b/.github/ISSUE_TEMPLATE/performance_issue.yml
new file mode 100644
index 0000000..c3ecca4
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/performance_issue.yml
@@ -0,0 +1,79 @@
+name: Performance issue
+description: Report performance problems or optimization opportunities
+title: "[PERFORMANCE] "
+labels:
+  - performance
+assignees:
+  - LoserCheems
+  - Evanwu1125
+  - SNHuan
+  - Thanksyy
+  - ftgreat
+  - zacliu2023
+  - juliohsu
+  - wubingheng111
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Provide enough detail about performance regressions or optimization opportunities so we can reproduce and diagnose them.
+  - type: textarea
+    id: issue-description
+    attributes:
+      label: Performance issue description
+      description: Summarize the performance problem.
+      placeholder: Forward latency increases when...
+    validations:
+      required: true
+  - type: textarea
+    id: current-performance
+    attributes:
+      label: Current performance metrics
+      description: Share benchmark numbers and configuration (sequence length, batch size, heads, head dimension, throughput, memory usage).
+      placeholder: |
+        Sequence length: 8192
+        Batch size: 2
+        Heads: 32
+        Head dim: 128
+        Speed: 15.2 ms/iteration
+        Memory: 8.5 GB
+    validations:
+      required: true
+  - type: textarea
+    id: expected-performance
+    attributes:
+      label: Expected performance
+      description: Explain what performance you expect and the baseline you are comparing against.
+      placeholder: Expect <10 ms/iteration based on Flash Attention benchmark...
+  - type: textarea
+    id: environment
+    attributes:
+      label: Environment information
+      description: Run the following command and paste the output.
+      placeholder: |
+        python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.version.cuda}'); print(f'GPU: {torch.cuda.get_device_name() if torch.cuda.is_available() else \"None\"}')"
+      render: shell
+    validations:
+      required: true
+  - type: textarea
+    id: benchmark-code
+    attributes:
+      label: Benchmark code
+      description: Provide the code snippet or script used to measure performance.
+      render: python
+  - type: textarea
+    id: profiling
+    attributes:
+      label: Profiling information
+      description: Include relevant excerpts from nsys, nvprof, or PyTorch profiler if available.
+  - type: textarea
+    id: system-info
+    attributes:
+      label: System information
+      description: GPU model, compute capability, CPU, RAM, and other hardware details.
+      placeholder: RTX 4090 24GB, compute capability 8.9, Intel i9-14900K, 64GB RAM
+  - type: textarea
+    id: additional-context
+    attributes:
+      label: Additional context
+      description: Mention regressions, different batch sizes, attention patterns, or other observations.
diff --git a/.github/PULL_REQUEST_TEMPLATE/bug_fix.yml b/.github/PULL_REQUEST_TEMPLATE/bug_fix.yml
new file mode 100644
index 0000000..766415f
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/bug_fix.yml
@@ -0,0 +1,60 @@
+name: Bug Fix
+description: Fix a bug with clear reproduction, scope, and tests
+title: "[BUG FIX] "
+labels:
+  - bug
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thanks for contributing a bug fix! Please complete the sections below so reviewers can understand and verify the change quickly.
+  - type: textarea
+    id: summary
+    attributes:
+      label: Summary
+      description: What bug is fixed and what parts of the codebase are impacted?
+      placeholder: Resolves crash when...
+    validations:
+      required: true
+  - type: textarea
+    id: root-cause
+    attributes:
+      label: Root cause
+      description: Briefly describe the underlying issue.
+      placeholder: The kernel assumed...
+  - type: textarea
+    id: changes
+    attributes:
+      label: Changes
+      description: Highlight the notable code-level modifications.
+      placeholder: Updated flash_dmattn_interface to...
+    validations:
+      required: true
+  - type: textarea
+    id: reproduction
+    attributes:
+      label: Reproduction steps or MRE
+      description: Provide steps or a minimal snippet that reproduces the original bug.
+      render: python
+  - type: textarea
+    id: tests
+    attributes:
+      label: Tests
+      description: List the tests you added or ran and their results.
+      placeholder: Ran benchmarks/forward_equivalence.py; added unit test...
+    validations:
+      required: true
+  - type: textarea
+    id: compatibility
+    attributes:
+      label: Compatibility
+      description: Note any migration concerns or backwards compatibility considerations.
+  - type: checkboxes
+    id: checklist
+    attributes:
+      label: Checklist
+      options:
+        - label: Linked issue provided
+        - label: Adds or updates tests
+        - label: Updates docs if needed
+        - label: No performance regressions observed
diff --git a/.github/PULL_REQUEST_TEMPLATE/feature_support.yml b/.github/PULL_REQUEST_TEMPLATE/feature_support.yml
new file mode 100644
index 0000000..46bf35d
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/feature_support.yml
@@ -0,0 +1,60 @@
+name: Feature Support
+description: Introduce a new feature with design context and tests
+title: "[FEATURE SUPPORT] "
+labels:
+  - feature
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Share enough detail about the new feature so reviewers can evaluate scope, design, and testing.
+  - type: textarea
+    id: summary
+    attributes:
+      label: Summary
+      description: What feature is being added and why?
+      placeholder: Adds configurable...
+    validations:
+      required: true
+  - type: textarea
+    id: design
+    attributes:
+      label: Design
+      description: Outline the design or architecture and mention alternatives considered.
+      placeholder: Uses new backend selection flow...
+  - type: textarea
+    id: changes
+    attributes:
+      label: Changes
+      description: Describe new or changed public APIs, configuration, or CLI behavior.
+      placeholder: Adds flash_dmattn.feature_flag...
+    validations:
+      required: true
+  - type: textarea
+    id: implementation-notes
+    attributes:
+      label: Implementation notes
+      description: Highlight tricky parts or noteworthy implementation details.
+  - type: textarea
+    id: tests
+    attributes:
+      label: Tests
+      description: List unit or integration tests you added or updated and how you validated them.
+      placeholder: Ran benchmarks/forward_equivalence.py; added example in...
+    validations:
+      required: true
+  - type: textarea
+    id: docs
+    attributes:
+      label: Documentation
+      description: Mention doc updates or examples that accompany this feature.
+  - type: checkboxes
+    id: checklist
+    attributes:
+      label: Checklist
+      options:
+        - label: Linked issue provided
+        - label: API stabilized
+        - label: Tests added or updated
+        - label: Docs added or updated
+        - label: No known performance regressions
diff --git a/.github/PULL_REQUEST_TEMPLATE/performance_optimization.yml b/.github/PULL_REQUEST_TEMPLATE/performance_optimization.yml
new file mode 100644
index 0000000..5a00c15
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/performance_optimization.yml
@@ -0,0 +1,61 @@
+name: Performance Optimization
+description: Optimize performance with benchmark evidence
+title: "[PERFORMANCE OPTIMIZATION] "
+labels:
+  - performance
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Document the optimization, methodology, and results so reviewers can validate gains and correctness.
+  - type: textarea
+    id: summary
+    attributes:
+      label: Summary
+      description: What is optimized and why?
+      placeholder: Improves forward latency for...
+    validations:
+      required: true
+  - type: textarea
+    id: baseline
+    attributes:
+      label: Baseline metrics
+      description: Provide the current performance numbers and environment.
+      placeholder: Baseline throughput 150 tok/s on H100 with...
+    validations:
+      required: true
+  - type: textarea
+    id: approach
+    attributes:
+      label: Approach
+      description: Describe the optimization techniques used.
+      placeholder: Introduced block-wise accumulation...
+  - type: textarea
+    id: results
+    attributes:
+      label: Results
+      description: Share before/after benchmarks and how to reproduce them.
+      placeholder: |
+        Before: 15.2 ms/iteration (benchmark command)
+        After: 9.8 ms/iteration (benchmark command)
+    validations:
+      required: true
+  - type: textarea
+    id: impact
+    attributes:
+      label: Impact
+      description: Note memory, throughput trade-offs, or hardware-specific considerations.
+  - type: textarea
+    id: risks
+    attributes:
+      label: Risks
+      description: Highlight edge cases, correctness risks, or gating tests added.
+  - type: checkboxes
+    id: checklist
+    attributes:
+      label: Checklist
+      options:
+        - label: Linked issue provided
+        - label: Benchmarks included and reproducible
+        - label: No accuracy regression
+        - label: Docs updated where needed
diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp
index dd52ac7..d26ecde 100644
--- a/csrc/flash_dmattn/flash_api.cpp
+++ b/csrc/flash_dmattn/flash_api.cpp
@@ -553,6 +553,7 @@ mha_fwd(
     return {out, softmax_lse, p};
 }
 
+// std::vector<at::Tensor>
 // mha_varlen_fwd(
 //     at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
@@ -774,19 +775,19 @@ mha_fwd(
 //     return {out, softmax_lse, p};
 // }
 
-// void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
-//     FP16_SWITCH(!params.is_bf16, [&] {
-//         HEADDIM_SWITCH(params.d, [&] {
-//             BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-//                 BOOL_SWITCH(params.has_mask, Has_mask, [&] {
-//                     BOOL_SWITCH(params.has_bias, Has_bias, [&] {
-//                         run_mha_bwd_<elem_type, kHeadDim, Is_causal, Has_mask, Has_bias>(params, stream);
-//                     });
-//                 });
-//             });
-//         });
-//     });
-// }
+void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+    FP16_SWITCH(!params.is_bf16, [&] {
+        HEADDIM_SWITCH(params.d, [&] {
+            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                BOOL_SWITCH(params.has_mask, Has_mask, [&] {
+                    BOOL_SWITCH(params.has_bias, Has_bias, [&] {
+                        run_mha_bwd_<elem_type, kHeadDim, Is_causal, Has_mask, Has_bias>(params, stream);
+                    });
+                });
+            });
+        });
+    });
+}
 
 std::vector<at::Tensor>
 mha_bwd(
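
Note for reviewers: run_mha_bwd uses the nested SWITCH macros to turn runtime flags (is_bf16, d, is_causal, has_mask, has_bias) into compile-time template parameters, so each (dtype, head dim, flag) combination dispatches to a separately compiled run_mha_bwd_ instantiation. For reference only (not part of this diff), below is a minimal sketch of how BOOL_SWITCH is commonly defined in Flash-Attention-derived codebases, typically in a static_switch.h header; the exact definition in this repository may differ:

    // Evaluates COND at runtime, then invokes the lambda with CONST_NAME
    // bound as a compile-time constant, so each branch compiles its own
    // template instantiation of the enclosed kernel-launch code.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)           \
      [&] {                                              \
        if (COND) {                                      \
          constexpr static bool CONST_NAME = true;       \
          return __VA_ARGS__();                          \
        } else {                                         \
          constexpr static bool CONST_NAME = false;      \
          return __VA_ARGS__();                          \
        }                                                \
      }()

FP16_SWITCH and HEADDIM_SWITCH follow the same pattern, binding elem_type (e.g. half vs bfloat16 element types) and kHeadDim respectively, which is why those names can appear as template arguments inside the lambdas above.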