@@ -351,26 +351,26 @@ end
351351 selection = select (:bar => :z , :a , :out )
352352 selection = StaticSelection (selection)
353353 retval_grad = 2.
354- ((mu_a_grad,), choice_value, choice_gradient ) = choice_gradients (trace, selection, retval_grad)
354+ ((mu_a_grad,), value_choices, gradient_choices ) = choice_gradients (trace, selection, retval_grad)
355355
356356 # check input gradient
357357 @test isapprox (mu_a_grad, finite_diff (f, (mu_a, theta, a, b, z, out), 1 , dx))
358358
359359 # check value from choice map
360- @test get_value (choice_value , :a ) == a
361- @test get_value (choice_value , :out ) == out
362- @test get_value (choice_value , :bar => :z ) == z
363- @test ! has_value (choice_value , :b ) # was not selected
364- @test length (get_submaps_shallow (choice_value )) == 1
365- @test length (get_values_shallow (choice_value )) == 2
360+ @test get_value (value_choices , :a ) == a
361+ @test get_value (value_choices , :out ) == out
362+ @test get_value (value_choices , :bar => :z ) == z
363+ @test ! has_value (value_choices , :b ) # was not selected
364+ @test length (get_submaps_shallow (value_choices )) == 1
365+ @test length (get_values_shallow (value_choices )) == 2
366366
367367 # check gradient from choice map
368- @test length (get_submaps_shallow (choice_gradient )) == 1
369- @test length (get_values_shallow (choice_gradient )) == 2
370- @test ! has_value (choice_gradient , :b ) # was not selected
371- @test isapprox (get_value (choice_gradient , :a ), finite_diff (f, (mu_a, theta, a, b, z, out), 3 , dx))
372- @test isapprox (get_value (choice_gradient , :out ), finite_diff (f, (mu_a, theta, a, b, z, out), 6 , dx))
373- @test isapprox (get_value (choice_gradient , :bar => :z ), finite_diff (f, (mu_a, theta, a, b, z, out), 5 , dx))
368+ @test length (get_submaps_shallow (gradient_choices )) == 1
369+ @test length (get_values_shallow (gradient_choices )) == 2
370+ @test ! has_value (gradient_choices , :b ) # was not selected
371+ @test isapprox (get_value (gradient_choices , :a ), finite_diff (f, (mu_a, theta, a, b, z, out), 3 , dx))
372+ @test isapprox (get_value (gradient_choices , :out ), finite_diff (f, (mu_a, theta, a, b, z, out), 6 , dx))
373+ @test isapprox (get_value (gradient_choices , :bar => :z ), finite_diff (f, (mu_a, theta, a, b, z, out), 5 , dx))
374374
375375 # reset the trainable parameter gradient
376376 zero_param_grad! (foo, :theta )
0 commit comments