Skip to content

Commit a2d0eaa

Browse files
committed
update speed test
1 parent 7979fd1 commit a2d0eaa

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

test/speed_tests.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,16 @@ Test.@testset "Speed Tests" begin
2222
@show compute_mode
2323

2424
rng = StableRNGs.StableRNG(1)
25-
nvars = 1
25+
ndata = 2^10
26+
ndimension = 1
27+
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
28+
r = rand(rng, data_dist, ndimension, ndata)
29+
r = convert.(Float32, r)
30+
31+
nvars = size(r, 1)
2632
naugs = nvars
2733
n_in = nvars + naugs
28-
n = 2^10
34+
2935
nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh))
3036

3137
icnf = ContinuousNormalizingFlows.construct(
@@ -47,10 +53,6 @@ Test.@testset "Speed Tests" begin
4753
),
4854
)
4955

50-
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
51-
r = rand(icnf.rng, data_dist, nvars, n)
52-
r = convert.(Float32, r)
53-
5456
df = DataFrames.DataFrame(transpose(r), :auto)
5557
model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5)
5658

0 commit comments

Comments
 (0)