@@ -964,3 +964,28 @@ def scan_step(prev_innov):
964
964
"innov" : np .full ((4 ,), - 0.5 ),
965
965
}
966
966
np .testing .assert_allclose (logp_fn (** test_point ), ref_logp_fn (** test_point ))
967
+
968
+
969
+ @pytest .mark .parametrize ("shift" , [1.5 , np .array ([- 0.5 , 1 , 0.3 ])])
970
+ @pytest .mark .parametrize ("scale" , [2.0 , np .array ([1.5 , 3.3 , 1.0 ])])
971
+ def test_multivariate_transform (shift , scale ):
972
+ mu = np .array ([0 , 0.9 , - 2.1 ])
973
+ cov = np .array ([[1 , 0 , 0.9 ], [0 , 1 , 0 ], [0.9 , 0 , 1 ]])
974
+ x_rv_raw = pt .random .multivariate_normal (mu , cov = cov )
975
+ x_rv = shift + x_rv_raw * scale
976
+ x_rv .name = "x"
977
+
978
+ x_vv = x_rv .clone ()
979
+ logp = factorized_joint_logprob ({x_rv : x_vv })[x_vv ]
980
+ assert_no_rvs (logp )
981
+
982
+ x_vv_test = np .array ([5.0 , 4.9 , - 6.3 ])
983
+ scale_mat = scale * np .eye (x_vv_test .shape [0 ])
984
+ np .testing .assert_almost_equal (
985
+ logp .eval ({x_vv : x_vv_test }),
986
+ sp .stats .multivariate_normal .logpdf (
987
+ x_vv_test ,
988
+ shift + mu * scale ,
989
+ scale_mat @ cov @ scale_mat .T ,
990
+ ),
991
+ )
0 commit comments