|
| 1 | +{ |
| 2 | + lib, |
| 3 | + stdenv, |
| 4 | + fetchFromGitHub, |
| 5 | + pythonOlder, |
| 6 | + buildPythonPackage, |
| 7 | + pytestCheckHook, |
| 8 | + setuptools, |
| 9 | + matplotlib, |
| 10 | + numpy, |
| 11 | + packaging, |
| 12 | + torch, |
| 13 | + tqdm, |
| 14 | + flask, |
| 15 | + flask-compress, |
| 16 | +}: |
| 17 | + |
| 18 | +buildPythonPackage rec { |
| 19 | + pname = "captum"; |
| 20 | + version = "0.7.0"; |
| 21 | + pyproject = true; |
| 22 | + |
| 23 | + build-system = [ setuptools ]; |
| 24 | + |
| 25 | + src = fetchFromGitHub { |
| 26 | + owner = "pytorch"; |
| 27 | + repo = "captum"; |
| 28 | + rev = "refs/tags/v${version}"; |
| 29 | + hash = "sha256-1VOvPqxn6CNnmv7M8fl7JrqRfJQUH2tnXRCUqKnl7i0="; |
| 30 | + }; |
| 31 | + |
| 32 | + dependencies = [ |
| 33 | + matplotlib |
| 34 | + numpy |
| 35 | + packaging |
| 36 | + torch |
| 37 | + tqdm |
| 38 | + ]; |
| 39 | + |
| 40 | + pythonImportsCheck = [ "captum" ]; |
| 41 | + |
| 42 | + nativeCheckInputs = [ |
| 43 | + pytestCheckHook |
| 44 | + flask |
| 45 | + flask-compress |
| 46 | + ]; |
| 47 | + |
| 48 | + disabledTestPaths = |
| 49 | + [ |
| 50 | + # These tests requires `parametrized` module (https://pypi.org/project/parametrized/) which seem to be unavailable on Nix. |
| 51 | + "tests/attr/test_dataloader_attr.py" |
| 52 | + "tests/attr/test_interpretable_input.py" |
| 53 | + "tests/attr/test_llm_attr.py" |
| 54 | + "tests/influence/_core/test_dataloader.py" |
| 55 | + "tests/influence/_core/test_tracin_aggregate_influence.py" |
| 56 | + "tests/influence/_core/test_tracin_intermediate_quantities.py" |
| 57 | + "tests/influence/_core/test_tracin_k_most_influential.py" |
| 58 | + "tests/influence/_core/test_tracin_regression.py" |
| 59 | + "tests/influence/_core/test_tracin_self_influence.py" |
| 60 | + "tests/influence/_core/test_tracin_show_progress.py" |
| 61 | + "tests/influence/_core/test_tracin_validation.py" |
| 62 | + "tests/influence/_core/test_tracin_xor.py" |
| 63 | + "tests/insights/test_contribution.py" |
| 64 | + "tests/module/test_binary_concrete_stochastic_gates.py" |
| 65 | + "tests/module/test_gaussian_stochastic_gates.py" |
| 66 | + ] |
| 67 | + ++ lib.optionals stdenv.hostPlatform.isDarwin [ |
| 68 | + # These tests are failing on macOS: |
| 69 | + # > E AttributeError: module 'torch.distributed' has no attribute 'init_process_group' |
| 70 | + "tests/attr/test_data_parallel.py" |
| 71 | + ] |
| 72 | + ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [ |
| 73 | + # Issue reported upstream at https://github.com/pytorch/captum/issues/1447 |
| 74 | + "tests/concept/test_tcav.py" |
| 75 | + ]; |
| 76 | + |
| 77 | + disabledTests = [ |
| 78 | + # Failing tests |
| 79 | + "test_softmax_classification_batch_multi_target" |
| 80 | + "test_softmax_classification_batch_zero_baseline" |
| 81 | + ]; |
| 82 | + |
| 83 | + meta = { |
| 84 | + description = "Model interpretability and understanding for PyTorch"; |
| 85 | + homepage = "https://github.com/pytorch/captum"; |
| 86 | + license = lib.licenses.bsd3; |
| 87 | + maintainers = with lib.maintainers; [ drupol ]; |
| 88 | + }; |
| 89 | +} |
0 commit comments