Commit 2b753a5
unittest: remove debug-print jit examples from unittest (#1851)
## 📌 Description

The debug print statements in `test_jit_examples` unittests clutter the CI output, making it difficult to identify useful information. This PR removes them from the unittests.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 989e82c commit 2b753a5

2 files changed, +1 -136 lines changed
README.md

Lines changed: 1 addition & 1 deletion

@@ -117,7 +117,7 @@ Check out [documentation](https://docs.flashinfer.ai/) for usage of batch decode
 
 ## Custom Attention Variants
 
-Starting from FlashInfer v0.2, users can customize their own attention variants with additional parameters. For more details, refer to our [JIT examples](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/test_jit_example.py).
+Starting from FlashInfer v0.2, users can customize their own attention variants with additional parameters. For more details, refer to our [JIT examples](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_jit_example.py).
 
 ## C++ API and TVM Bindings
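The README link above points at the JIT examples that show how a custom attention variant is defined and compiled. As a rough sketch of that API, the snippet below keeps the structure of the first removed test but drops the debug `printf` that cluttered CI logs; the import paths, the `single_prefill_with_kv_cache_with_jit_module` helper, and the `IdentityLogits` / `identity_logits_example` names are assumptions for illustration, not the canonical example (that remains `tests/utils/test_jit_example.py`).

```python
# Sketch only: import locations may differ across FlashInfer versions
# (assumption: these names are importable as shown).
import functools
import math

import torch
from flashinfer.jit import gen_customize_single_prefill_module
from flashinfer.prefill import single_prefill_with_kv_cache_with_jit_module
from flashinfer.utils import MaskMode

# A pass-through variant: same structure as the removed DebugPrintLogits
# example, minus the printf inside the logits transform.
variant_decl = r"""
struct IdentityLogits : AttentionVariantBase {
  static constexpr bool use_softmax = true;

  uint32_t window_left, qo_len, kv_len;
  float sm_scale_log2;

  template <typename Params>
  __device__ __host__ IdentityLogits(const Params& params, uint32_t batch_idx,
                                     uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
    sm_scale_log2 = params.sm_scale * math::log2e;
  }

  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    return logits;  // hook point for custom per-logit transforms
  });
};
"""

# Build a single-prefill kernel specialized for the custom variant.
jit_module = gen_customize_single_prefill_module(
    "fa2",                      # backend
    "identity_logits_example",  # uri (hypothetical)
    torch.float16,              # dtype_q
    torch.float16,              # dtype_kv
    torch.float16,              # dtype_o
    128,                        # hidden_dim_qk
    128,                        # hidden_dim_vo
    [],                         # additional_tensor_names
    [],                         # additional_tensor_dtypes
    ["sm_scale"],               # additional_scalar_names
    ["double"],                 # additional_scalar_dtypes
    "IdentityLogits",
    variant_decl,
).build_and_load()

# Run the compiled kernel on random q/k/v, as the removed tests did.
run = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module)
q = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
o = run(q, k, v, 1.0 / math.sqrt(128), mask_mode=MaskMode.NON_CAUSAL.value)
```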

tests/utils/test_jit_example.py

Lines changed: 0 additions & 135 deletions

@@ -572,145 +572,10 @@ def test_batch_prefill_sm90_flash_sigmoid():
     torch.testing.assert_close(o_paged, o_ref, rtol=2e-2, atol=2e-2)


-def test_debug_print_logits():
-    torch.manual_seed(42)
-    variant_decl = r"""
-struct DebugPrintLogits : AttentionVariantBase {
-  static constexpr bool use_softmax = true;
-
-  uint32_t window_left, qo_len, kv_len;
-  float sm_scale_log2;
-
-  // Create closure
-  template <typename Params>
-  __device__ __host__ DebugPrintLogits(const Params& params, uint32_t batch_idx,
-                                       uint8_t* smem_ptr) {
-    qo_len = params.get_qo_len(batch_idx);
-    kv_len = params.get_kv_len(batch_idx);
-    window_left = kv_len;
-    sm_scale_log2 = params.sm_scale * math::log2e;
-  }
-
-  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
-    if (logits >= 5) {
-      printf("Large logits at qo_idx=%d, kv_idx=%d, qo_head_idx=%d, kv_head_idx=%d: %.3f\n",
-             qo_idx, kv_idx, qo_head_idx, kv_head_idx, float(logits));
-    }
-    return logits;
-  });
-};
-"""
-    jit_module = gen_customize_single_prefill_module(
-        "fa2",  # backend
-        "batch_prefill_debug_print_logits",  # uri
-        torch.float16,  # dtype_q
-        torch.float16,  # dtype_kv
-        torch.float16,  # dtype_o
-        128,  # hidden_dim_qk
-        128,  # hidden_dim_vo
-        [],  # additional_tensor_names
-        [],  # additional_tensor_dtypes
-        ["sm_scale"],  # additional_scalar_names
-        ["double"],  # additional_scalar_dtypes
-        "DebugPrintLogits",
-        variant_decl,
-    ).build_and_load()
-
-    f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module)
-
-    q = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda")
-    k = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
-    v = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
-    sm_scale = 1.0 / math.sqrt(128)
-    o = f(q, k, v, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value)
-
-    p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale
-    o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half()
-    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
-
-
-def test_sm90_debug_print_logits():
-    if not is_sm90a_supported(torch.device("cuda")):
-        pytest.skip("SM90A is not supported")
-
-    torch.manual_seed(42)
-    variant_decl = r"""
-struct DebugPrintLogits : AttentionVariantBase {
-  float sm_scale_log2;
-  int qo_len, kv_len;
-
-  // Init
-  template <typename MainloopParams, typename BlockCoord>
-  __device__ __host__ DebugPrintLogits(const MainloopParams& params, const BlockCoord& block_coord) {
-    sm_scale_log2 = params.additional_params.sm_scale * math::log2e;
-    auto [_, __, ___, ____, _____, qo_len_, kv_len_, batch_idx] =
-        block_coord;
-
-    qo_len = qo_len_;
-    kv_len = kv_len_;
-  }
-
-
-  template <int NUM_ROWS_PER_THREAD>
-  __device__ auto GetAttentionUpdater() {
-    return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE*/false>(sm_scale_log2);
-  }
-
-
-  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
-    if (qo_idx < qo_len && kv_idx < kv_len) {
-      printf(
-          "---> LOGITS DEBUG: "
-          "qo_idx=%-5d "
-          "kv_idx=%-5d "
-          "sm_scale_log2=%-12.5f "
-          "logits=%-12.5f "
-          "\n",
-          qo_idx,
-          kv_idx,
-          sm_scale_log2,
-          static_cast<float>(logits));
-    }
-    logits *= sm_scale_log2;
-    return logits;
-  })
-};
-"""
-    jit_module = gen_customize_single_prefill_module(
-        "fa3",  # backend
-        "debug_print_logits",  # uri
-        torch.float16,  # dtype_q
-        torch.float16,  # dtype_kv
-        torch.float16,  # dtype_o
-        128,  # hidden_dim_qk
-        128,  # hidden_dim_vo
-        [],  # additional_tensor_names
-        [],  # additional_tensor_dtypes
-        ["sm_scale"],  # additional_scalar_names
-        ["double"],  # additional_scalar_dtypes
-        "DebugPrintLogits",
-        variant_decl,
-    ).build_and_load()
-
-    f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module)
-
-    q = torch.randn(16, 2, 128, dtype=torch.float16, device="cuda")
-    k = torch.randn(16, 1, 128, dtype=torch.float16, device="cuda")
-    v = torch.randn(16, 1, 128, dtype=torch.float16, device="cuda")
-    sm_scale = 1.0 / math.sqrt(128)
-    o = f(q, k, v, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value)
-
-    p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale
-    o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half()
-    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
-
-
 if __name__ == "__main__":
     test_single_decode_mask()
     test_flash_sigmoid()
     test_dump_logits()
-    test_debug_print_logits()
-    test_sm90_debug_print_logits()
     test_batch_decode_flash_sigmoid(False)
     test_batch_decode_flash_sigmoid(True)
     test_batch_prefill_flash_sigmoid()
