11{
22 lib ,
33 buildPythonPackage ,
4- fetchFromGitHub ,
4+ fetchPypi ,
55
66 # build-system
77 poetry-core ,
1818
1919buildPythonPackage rec {
2020 pname = "oryx" ;
21- version = "0.2.7 " ;
21+ version = "0.2.9 " ;
2222 pyproject = true ;
2323
24- src = fetchFromGitHub {
25- owner = "jax-ml" ;
26- repo = "oryx" ;
27- tag = "v${ version } " ;
28- hash = "sha256-1n7ogGuFNAeOyXWe0/pAouhg2+aA3MXxlCcsrfqRTdU=" ;
24+ # No more tags on GitHub. See https://github.com/jax-ml/oryx/issues/95
25+ src = fetchPypi {
26+ inherit pname version ;
27+ hash = "sha256-HlKUnguTNfs7gSqIJ0n2EjjLXPUgtI2JsQM70wKMeXs=" ;
2928 } ;
3029
3130 build-system = [ poetry-core ] ;
@@ -43,6 +42,72 @@ buildPythonPackage rec {
4342 pytestCheckHook
4443 ] ;
4544
45+ disabledTests = [
46+ # ValueError: Number of devices 1 must equal the product of mesh_shape (1, 2)
47+ "test_plant"
48+ "test_plant_before_shmap"
49+ "test_plant_inside_shmap_fails"
50+ "test_reap"
51+ "test_reap_before_shmap"
52+ "test_reap_inside_shmap_fails"
53+
54+ # ValueError: Variable has already been reaped
55+ "test_call_list"
56+ "test_call_tuple"
57+ "test_dense_combinator"
58+ "test_dense_function"
59+ "test_dense_imperative"
60+ "test_function_in_combinator_in_function"
61+ "test_grad_of_function_with_literal"
62+ "test_grad_of_shared_layer"
63+ "test_grad_of_stateful_function"
64+ "test_kwargs_rng"
65+ "test_kwargs_training"
66+ "test_kwargs_training_rng"
67+ "test_reshape_call"
68+ "test_scale_by_adam_should_scale_by_adam"
69+ "test_scale_by_schedule_should_update_scale"
70+ "test_scale_by_stddev_should_scale_by_stddev"
71+ "test_trace_should_keep_track_of_momentum_with_nesterov"
72+
73+ # NotImplementedError: No registered inverse for `split`
74+ "test_inverse_of_split"
75+
76+ # jax.errors.UnexpectedTracerError: Encountered an unexpected tracer
77+ "test_can_plant_into_jvp_of_custom_jvp_function_unimplemented"
78+ "test_forward_Scale"
79+
80+ # ValueError: No variable declared for assign: update_1
81+ "test_optimizer_adam"
82+ "test_optimizer_noisy_sgd"
83+ "test_optimizer_rmsprop"
84+ "test_optimizer_sgd"
85+ "test_optimizer_sgd_with_momentum"
86+ "test_optimizer_sgd_with_nesterov_momentum"
87+
88+ # AssertionError
89+ # ACTUAL: array(-2.337877, dtype=float32)
90+ # DESIRED: array(0., dtype=float32)
91+ "test_can_map_over_batches_with_vmap_and_reduce_to_scalar_log_prob"
92+ "test_vmapping_distribution_reduces_to_scalar_log_prob"
93+
94+ # TypeError: _dot_general_shape_rule() missing 1 required keyword-only argument: 'out_sharding'
95+ "test_can_rewrite_dot_to_einsu"
96+
97+ # AttributeError: 'float' object has no attribute 'shape'
98+ "test_add_noise_should_add_noise"
99+ "test_apply_every_should_delay_updates"
100+
101+ # TypeError: Error interpreting argument to functools.partial(...) as an abstract array
102+ "test_can_rewrite_nested_expression_into_single_einsum"
103+ ] ;
104+
105+ disabledTestPaths = [
106+ # ValueError: Variable has already been reaped
107+ "oryx/experimental/nn/normalization_test.py"
108+ "oryx/experimental/nn/pooling_test.py"
109+ ] ;
110+
46111 meta = {
47112 description = "Library for probabilistic programming and deep learning built on top of Jax" ;
48113 homepage = "https://github.com/jax-ml/oryx" ;
0 commit comments