Skip to content

Commit 4c2f2f1

Browse files
authored
Merge pull request #297146 from GaetanLepage/chex
python3Packages.jax: towards fixing dependencies
2 parents ebde306 + fb7f0da commit 4c2f2f1

File tree

4 files changed

+22
-8
lines changed

4 files changed

+22
-8
lines changed

pkgs/development/python-modules/distrax/default.nix

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,8 @@ buildPythonPackage rec {
8282
homepage = "https://github.com/deepmind/distrax";
8383
license = licenses.asl20;
8484
maintainers = with maintainers; [ onny ];
85+
# Several tests fail with:
86+
# AssertionError: [Chex] Assertion assert_type failed: Error in type compatibility check
87+
broken = true;
8588
};
8689
}

pkgs/development/python-modules/equinox/default.nix

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ buildPythonPackage rec {
4747

4848
pythonImportsCheck = [ "equinox" ];
4949

50+
disabledTests = [
51+
# Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
52+
"test_tracetime"
53+
];
54+
5055
meta = with lib; {
5156
description = "A JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
5257
changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";

pkgs/development/python-modules/flax/default.nix

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@
2525

2626
buildPythonPackage rec {
2727
pname = "flax";
28-
version = "0.8.1";
28+
version = "0.8.2";
2929
pyproject = true;
3030

31-
disabled = pythonOlder "3.8";
31+
disabled = pythonOlder "3.9";
3232

3333
src = fetchFromGitHub {
3434
owner = "google";
3535
repo = "flax";
3636
rev = "refs/tags/v${version}";
37-
hash = "sha256-3UzMSJoKw+V1WLBJ+Zf7aF7CDNBsvWnRUfNgb3K4v1A=";
37+
hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
3838
};
3939

4040
nativeBuildInputs = [
@@ -87,6 +87,7 @@ buildPythonPackage rec {
8787
# `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
8888
# would be limited anyway.
8989
"examples/*"
90+
"flax/experimental/nnx/examples/*"
9091
# See https://github.com/google/flax/issues/3232.
9192
"tests/jax_utils_test.py"
9293
# Requires tree

pkgs/development/python-modules/optax/default.nix

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{ lib
22
, absl-py
33
, buildPythonPackage
4+
, flit-core
45
, chex
56
, fetchFromGitHub
67
, jaxlib
@@ -11,23 +12,27 @@
1112

1213
buildPythonPackage rec {
1314
pname = "optax";
14-
version = "0.2.1";
15-
format = "setuptools";
15+
version = "0.2.2";
16+
pyproject = true;
1617

17-
disabled = pythonOlder "3.7";
18+
disabled = pythonOlder "3.9";
1819

1920
src = fetchFromGitHub {
2021
owner = "deepmind";
21-
repo = pname;
22+
repo = "optax";
2223
rev = "refs/tags/v${version}";
23-
hash = "sha256-vimsVZV5Z11euLxsu998pMQZ0hG3xl96D3h9iONtl/E=";
24+
hash = "sha256-sBiKUuQR89mttc9Njrh1aeUJOYdlcF7Nlj3/+Y7OMb4=";
2425
};
2526

2627
outputs = [
2728
"out"
2829
"testsout"
2930
];
3031

32+
nativeBuildInputs = [
33+
flit-core
34+
];
35+
3136
buildInputs = [
3237
jaxlib
3338
];

0 commit comments

Comments
 (0)