1
+ using Reactant, Test, Random
2
+ using Reactant: ProbProg
3
+
4
+ # Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/
5
+
6
+ normal (rng, μ, σ, shape) = μ .+ σ .* randn (rng, shape)
7
+ normal_logpdf (x, μ, σ, _) = - sum (log .(σ)) - sum ((μ .- x) .^ 2 ) / (2 * σ^ 2 )
8
+
9
+ function my_model (seed, xs)
10
+ rng = Random. default_rng ()
11
+ Random. seed! (rng, seed)
12
+
13
+ slope = ProbProg. sample (
14
+ normal, rng, 0.0 , 2.0 , (1 ,); symbol= :slope , logpdf= normal_logpdf
15
+ )
16
+ intercept = ProbProg. sample (
17
+ normal, rng, 0.0 , 10.0 , (1 ,); symbol= :intercept , logpdf= normal_logpdf
18
+ )
19
+
20
+ ys = ProbProg. sample (
21
+ normal,
22
+ rng,
23
+ slope .* xs .+ intercept,
24
+ 1.0 ,
25
+ (length (xs),);
26
+ symbol= :ys ,
27
+ logpdf= normal_logpdf,
28
+ )
29
+
30
+ return rng. seed, ys
31
+ end
32
+
33
+ function my_inference_program (xs, ys, num_iters)
34
+ xs_r = Reactant. to_rarray (xs)
35
+
36
+ constraints = ProbProg. choicemap ()
37
+ constraints[:ys ] = [ys]
38
+
39
+ seed = Reactant. to_rarray (UInt64[1 , 4 ])
40
+
41
+ trace, _ = ProbProg. generate (my_model, seed, xs_r; constraints)
42
+ trace. args = (trace. retval[1 ], trace. args[2 : end ]. .. ) # TODO : this is a temporary hack
43
+
44
+ for i in 1 : num_iters
45
+ trace, _ = ProbProg. metropolis_hastings (trace, ProbProg. select (:slope ))
46
+ trace, _ = ProbProg. metropolis_hastings (trace, ProbProg. select (:intercept ))
47
+ choices = ProbProg. get_choices (trace)
48
+ @show i, choices[:slope ], choices[:intercept ]
49
+ end
50
+
51
+ choices = ProbProg. get_choices (trace)
52
+ return (choices[:slope ], choices[:intercept ])
53
+ end
54
+
55
+ @testset " linear_regression" begin
56
+ @testset " simulate" begin
57
+ seed = Reactant. to_rarray (UInt64[1 , 4 ])
58
+
59
+ xs = [1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 , 9.0 , 10.0 ]
60
+ xs_r = Reactant. to_rarray (xs)
61
+
62
+ trace = ProbProg. simulate (my_model, seed, xs_r)
63
+
64
+ @test haskey (trace. choices, :slope )
65
+ @test haskey (trace. choices, :intercept )
66
+ @test haskey (trace. choices, :ys )
67
+ end
68
+
69
+ @testset " inference" begin
70
+ xs = [1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 , 9.0 , 10.0 ]
71
+ ys = [8.23 , 5.87 , 3.99 , 2.59 , 0.23 , - 0.66 , - 3.53 , - 6.91 , - 7.24 , - 9.90 ]
72
+
73
+ slope, intercept = my_inference_program (xs, ys, 1000 )
74
+
75
+ @show slope, intercept
76
+ end
77
+ end
0 commit comments