@@ -59,15 +59,17 @@ function finitediff_withgradient(f, x...)
5959end 
6060
6161function  check_equal_leaves (a, b; rtol= 1e-4 , atol= 1e-4 )
62+     equal =  true 
6263    fmapstructure_with_path (a, b) do  kp, x, y
6364        if  x isa  AbstractArray
6465            #  @show kp
65-             @assert  x  ≈  y  rtol= rtol  atol= atol 
66-         #  elseif x isa Number 
67-         #      @show kp 
68-         #      @assert x ≈ y rtol=rtol atol=atol 
66+             #   @assert isapprox(x, y;  rtol,  atol) 
67+              if   ! isapprox (x, y; rtol, atol) 
68+                 equal  =   false 
69+             end 
6970        end 
7071    end 
72+     @assert  equal
7173end 
7274
7375function  test_gradients (
@@ -109,15 +111,15 @@ function test_gradients(
109111            f64 =  f |>  Flux. f64
110112            xs64 =  xs .| >  Flux. f64
111113            y_fd, g_fd =  finitediff_withgradient ((xs... ) ->  loss (f64, graph, xs... ), xs64... )
112-             @assert  y  ≈   y_fd rtol= rtol  atol= atol 
114+             @assert  isapprox (y,  y_fd;  rtol,  atol) 
113115            check_equal_leaves (g, g_fd; rtol, atol)
114116        end 
115117
116118        if  test_gpu
117119            #  Zygote gradient with respect to input on GPU.
118120            y_gpu, g_gpu =  Zygote. withgradient ((xs... ) ->  loss (f_gpu, graph_gpu, xs... ), xs_gpu... )
119121            @assert  get_device (g_gpu) ==  get_device (xs_gpu)
120-             @assert  y_gpu  ≈  y  rtol= rtol  atol= atol 
122+             @assert  isapprox ( y_gpu, y;  rtol,  atol) 
121123            check_equal_leaves (g_gpu |>  cpu_dev, g; rtol, atol)
122124        end 
123125    end 
@@ -132,23 +134,22 @@ function test_gradients(
132134            ps, re =  Flux. destructure (f64)
133135            y_fd, g_fd =  finitediff_withgradient (ps ->  loss (re (ps),graph, xs... ), ps)
134136            g_fd =  (re (g_fd[1 ]),)
135-             @assert  y  ≈   y_fd rtol= rtol  atol= atol 
137+             @assert  isapprox (y,  y_fd;  rtol,  atol) 
136138            check_equal_leaves (g, g_fd; rtol, atol)
137139        end 
138140
139141        if  test_gpu
140142            #  Zygote gradient with respect to f on GPU.
141143            y_gpu, g_gpu =  Zygote. withgradient (f ->  loss (f,graph_gpu, xs_gpu... ), f_gpu)
142144            #  @assert get_device(g_gpu) == get_device(xs_gpu)
143-             @assert  y_gpu  ≈  y  rtol= rtol  atol= atol 
145+             @assert  isapprox ( y_gpu, y;  rtol,  atol) 
144146            check_equal_leaves (g_gpu |>  cpu_dev, g; rtol, atol)
145147        end 
146148    end 
147149    @test  true  #  if we reach here, the test passed
148-     return  true   #  return true in case we want to put a @test_broken in the caller 
150+     return  true 
149151end 
150152
151- 
152153function  generate_test_graphs (graph_type)
153154    adj1 =  [0  1  0  1 
154155            1  0  1  0 
0 commit comments