diff --git a/tests/cli/cli_cmd_helpers.py b/tests/cli/cli_cmd_helpers.py
new file mode 100644
index 0000000000..b6407aaaa2
--- /dev/null
+++ b/tests/cli/cli_cmd_helpers.py
@@ -0,0 +1,13 @@
+from click.testing import CliRunner
+
+from flashinfer.__main__ import cli
+
+
+def _test_cmd_helper(cmd: list[str]) -> str:
+    """
+    Invoke the CLI with `cmd`, assert that it exits with code 0, and return its output
+    """
+    runner = CliRunner()
+    result = runner.invoke(cli, cmd)
+    assert result.exit_code == 0, result.output
+    return result.output
diff --git a/tests/cli/test_cli_cmds.py b/tests/cli/test_cli_cmds.py
new file mode 100644
index 0000000000..a6a3c74f1b
--- /dev/null
+++ b/tests/cli/test_cli_cmds.py
@@ -0,0 +1,133 @@
+"""
+Test that the CLI commands work as expected.
+
+In general there can be two types of tests for each command:
+- Real tests (with suffix `_real`) that invoke the commands without any mocking
+- Mocked tests (with suffix `_mocked`) that use monkeypatch to mock logic that would
+  otherwise be slow (e.g. downloading cubins, filesystem calls, etc), and also to
+  create deterministic state so we can check for expected output (e.g. number of cubins)
+
+These tests don't require a GPU. CLI tests that require a GPU are in test_cli_cmds_gpu.py.
+"""
+
+from cli_cmd_helpers import _test_cmd_helper
+from flashinfer.artifacts import ArtifactPath
+
+
+def test_show_config_cmd_real():
+    """
+    Test that show-config command works as expected
+    """
+    out = _test_cmd_helper(["show-config"])
+
+    # Basic sections present
+    assert "=== Torch Version Info ===" in out
+    assert "=== Environment Variables ===" in out
+    assert "=== Artifact Path ===" in out
+    assert "=== Downloaded Cubins ===" in out
+
+
+def test_show_config_cmd_mocked(monkeypatch):
+    """
+    Test that show-config command works as expected, but with mocked cubin status
+    """
+    # Don't check filesystem for cubins
+    monkeypatch.setattr(
+        "flashinfer.__main__.get_artifacts_status",
+        lambda: (("foo.cubin", True), ("bar.cubin", False)),
+    )
+    # Avoid module registration/inspection
+    monkeypatch.setattr(
+        "flashinfer.__main__._ensure_modules_registered",
+        lambda: [],
+    )
+
+    out = _test_cmd_helper(["show-config"])
+
+    # Uses our monkeypatched data
+    assert "Downloaded 1/2 cubins" in out
+
+
+def test_cli_group_help_real():
+    """
+    Test that the CLI group runs without error and sanity checks the output
+    """
+    out = _test_cmd_helper([])
+    assert "FlashInfer CLI" in out or "Usage" in out
+
+
+def test_download_cubin_flag_mocked(monkeypatch):
+    # This just tests that the flag is parsed correctly, so we can monkeypatch
+    # download_artifacts to avoid the latency of downloading cubins
+    monkeypatch.setattr("flashinfer.__main__.download_artifacts", lambda: None)
+
+    out = _test_cmd_helper(["--download-cubin"])
+    assert "All cubin download tasks completed successfully" in out
+
+
+def test_download_cubin_cmd_mocked(monkeypatch):
+    """
+    Test that --download-cubin can download a single cubin using a mocked cubin path
+    """
+    # Return a real cubin path relative to the repository so it can be downloaded
+    fmha_cubin = "fmhaSm100aKernel_QE4m3KvE2m1OE4m3H128PagedKvCausalP32VarSeqQ128Kv128PersistentContext.cubin"
+    monkeypatch.setattr(
+        "flashinfer.artifacts.get_cubin_file_list",
+        lambda: [f"{ArtifactPath.TRTLLM_GEN_FMHA}/{fmha_cubin}"],
+    )
+
+    out = _test_cmd_helper(["--download-cubin"])
+    assert "All cubin download tasks completed successfully" in out
+
+
+def test_list_cubins_cmd_real():
+    """
+    Test that list-cubins command runs without mocking and prints its column headers
+    """
+    out = _test_cmd_helper(["list-cubins"])
+    assert "Cubin" in out and "Status" in out
+
+
+def test_list_cubins_cmd_mocked(monkeypatch):
+    """
+    Test that list-cubins command lists every cubin from a mocked cubin status
+    """
+    monkeypatch.setattr(
+        "flashinfer.__main__.get_artifacts_status",
+        lambda: (("foo.cubin", True), ("bar.cubin", False)),
+    )
+
+    out = _test_cmd_helper(["list-cubins"])
+    assert "foo.cubin" in out and "bar.cubin" in out
+
+
+def test_clear_cache_cmd_mocked(monkeypatch):
+    """
+    Test that clear-cache command works without actually clearing the cache.
+
+    This doesn't test much, just a basic sanity check.
+    """
+    monkeypatch.setattr("flashinfer.__main__.clear_cache_dir", lambda: None)
+
+    out = _test_cmd_helper(["clear-cache"])
+    assert "Cache cleared successfully" in out
+
+
+# TODO: add test that actually clears the cache
+# need to check that there aren't side effects if we do this
+
+
+def test_clear_cubin_cmd_mocked(monkeypatch):
+    """
+    Test that clear-cubin command works without actually clearing the cubin.
+
+    This doesn't test much, just a basic sanity check.
+    """
+    monkeypatch.setattr("flashinfer.__main__.clear_cubin", lambda: None)
+
+    out = _test_cmd_helper(["clear-cubin"])
+    assert "Cubin cleared successfully" in out
+
+
+# TODO: add test that actually clears the cubins
+# need to check that there aren't side effects if we do this
diff --git a/tests/cli/test_cli_cmds_gpu.py b/tests/cli/test_cli_cmds_gpu.py
new file mode 100644
index 0000000000..2ab1c3a0c1
--- /dev/null
+++ b/tests/cli/test_cli_cmds_gpu.py
@@ -0,0 +1,43 @@
+"""
+Tests the module-status and list-modules commands
+
+This is factored out from test_cli_cmds.py because these tests require a GPU.
+"""
+
+from cli_cmd_helpers import _test_cmd_helper
+
+
+_MOCKED_CUDA_ARCH_LIST = "7.5 8.0 8.9 9.0a 10.0a"
+
+
+def test_module_status_cmd_mocked(monkeypatch):
+    """
+    Test that module-status command runs without error and sanity checks the output
+
+    The only mock is to set the CUDA architecture list via monkeypatch, for isolation.
+    """
+    monkeypatch.setenv("FLASHINFER_CUDA_ARCH_LIST", _MOCKED_CUDA_ARCH_LIST)
+    out = _test_cmd_helper(["module-status"])
+    assert "=== Summary ===" in out
+    assert "Total modules:" in out
+    assert "AOT compiled:" in out
+    assert "JIT compiled:" in out
+    assert "Not compiled:" in out
+
+
+# TODO: test module-status command with different filters
+# TODO: test module-status command with detailed output
+
+
+def test_list_modules_cmd_mocked(monkeypatch):
+    """
+    Test that list-modules command runs without error and sanity checks the output
+
+    The only mock is to set the CUDA architecture list via monkeypatch, for isolation.
+    """
+    monkeypatch.setenv("FLASHINFER_CUDA_ARCH_LIST", _MOCKED_CUDA_ARCH_LIST)
+    out = _test_cmd_helper(["list-modules"])
+    assert "Available compilation modules:" in out
+
+
+# TODO: test list-modules command with module name
diff --git a/tests/cli/test_cli_show_config.py b/tests/cli/test_cli_show_config.py
deleted file mode 100644
index 4df2568784..0000000000
--- a/tests/cli/test_cli_show_config.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from click.testing import CliRunner
-
-from flashinfer.__main__ import cli
-
-
-def test_show_config_cmd_smoke(monkeypatch):
-    runner = CliRunner()
-    result = runner.invoke(cli, ["show-config"])
-    assert result.exit_code == 0, result.output
-    out = result.output
-
-    # Basic sections present
-    assert "=== Torch Version Info ===" in out
-    assert "=== Environment Variables ===" in out
-    assert "=== Artifact Path ===" in out
-    assert "=== Downloaded Cubins ===" in out