11using HybridVariationalInference
22using Test
3+ using CUDA, GPUArraysCore
34using Lux
45using StatsFuns: logistic
5- using CUDA, GPUArraysCore
66
77
88@testset " LuxModelApplicator" begin
@@ -13,18 +13,20 @@ using CUDA, GPUArraysCore
1313 Dense (n_covar * 4 => n_covar * 4 , tanh),
1414 Dense (n_covar * 4 => n_out, logistic, use_bias= false ),
1515 );
16- g = construct_LuxApplicator (g_chain; device = cpu_device ());
16+ g, ϕ = construct_LuxApplicator (g_chain, Float64; device = cpu_device ());
17+ @test eltype (ϕ) == Float64
18+ g, ϕ = construct_LuxApplicator (g_chain; device = cpu_device ());
19+ @test eltype (ϕ) == Float32
1720 n_site = 3
1821 x = rand (Float32, n_covar, n_site)
19- ϕ = randn (Float32, Lux. parameterlength (g_chain))
22+ # ϕ = randn(Float32, Lux.parameterlength(g_chain))
2023 y = g (x, ϕ)
2124 @test size (y) == (n_out, n_site)
2225 #
23- g = construct_LuxApplicator (g_chain; device = gpu_device ());
24- n_site = 3
2526 x = rand (Float32, n_covar, n_site) |> gpu_device ()
26- ϕ = randn (Float32, Lux. parameterlength (g_chain)) |> gpu_device ()
27- y = g (x, ϕ)
27+ ϕ_gpu = ϕ |> gpu_device ()
28+ # ϕ = randn(Float32, Lux.parameterlength(g_chain)) |> gpu_device()
29+ y = g (x, ϕ_gpu)
2830 # @test ϕ isa GPUArraysCore.AbstractGPUArray
2931 @test size (y) == (n_out, n_site)
3032end ;
0 commit comments