@@ -1801,3 +1801,110 @@ def test_sample_prior_predictive_samples_deprecated_warns() -> None:
18011801 match = "The samples argument has been deprecated"
18021802 with pytest .warns (DeprecationWarning , match = match ):
18031803 pm .sample_prior_predictive (model = m , samples = 10 )
1804+
1805+
1806+ class TestSamplePrior :
1807+ def test_basic_prior_sampling (self , seeded_test ):
1808+ """Test that sample_prior only samples from unobserved random variables."""
1809+ with pm .Model () as model :
1810+ mu = pm .Normal ("mu" , 0 , 1 )
1811+ sigma = pm .HalfNormal ("sigma" , 1 )
1812+ y = pm .Normal ("y" , mu , sigma , observed = [1 , 2 , 3 ])
1813+ det = pm .Deterministic ("det" , mu + sigma )
1814+
1815+ prior_samples = pm .sample_prior (draws = 100 , return_inferencedata = False )
1816+
1817+ # Should contain unobserved RVs and deterministics that do not
1818+ # depend on observed variables, but not observed variables
1819+ assert "mu" in prior_samples
1820+ assert "sigma" in prior_samples
1821+ assert "det" in prior_samples # deterministic is included
1822+ assert "y" not in prior_samples # observed variable
1823+
1824+ assert prior_samples ["mu" ].shape == (100 ,)
1825+ assert prior_samples ["sigma" ].shape == (100 ,)
1826+ assert prior_samples ["det" ].shape == (100 ,)
1827+
1828+ def test_specific_var_names (self , seeded_test ):
1829+ """Test sampling specific variables from the prior."""
1830+ with pm .Model () as model :
1831+ mu = pm .Normal ("mu" , 0 , 1 )
1832+ sigma = pm .HalfNormal ("sigma" , 1 )
1833+ y = pm .Normal ("y" , mu , sigma , observed = [1 , 2 , 3 ])
1834+
1835+ # Sample only mu
1836+ mu_samples = pm .sample_prior (draws = 100 , var_names = ["mu" ], return_inferencedata = False )
1837+
1838+ assert "mu" in mu_samples
1839+ assert "sigma" not in mu_samples
1840+ assert "y" not in mu_samples
1841+ assert mu_samples ["mu" ].shape == (100 ,)
1842+
1843+ def test_multivariate_prior (self , seeded_test ):
1844+ """Test sampling from multivariate priors."""
1845+ with pm .Model () as model :
1846+ mu = pm .Normal ("mu" , 0 , 1 , size = 3 )
1847+ # Use a simpler multivariate setup to avoid LKJCholeskyCov issues
1848+ mv = pm .MvNormal ("mv" , mu , cov = np .eye (3 ), size = 4 )
1849+
1850+ prior_samples = pm .sample_prior (draws = 50 , return_inferencedata = False )
1851+
1852+ assert "mu" in prior_samples
1853+ assert "mv" in prior_samples
1854+ assert prior_samples ["mu" ].shape == (50 , 3 )
1855+ assert prior_samples ["mv" ].shape == (50 , 4 , 3 )
1856+
1857+ def test_only_requested_variables (self , seeded_test ):
1858+ """Test that sample_prior only returns the requested variables."""
1859+ with pm .Model () as model :
1860+ mu = pm .Normal ("mu" , 0 , 1 )
1861+ sigma = pm .HalfNormal ("sigma" , 1 )
1862+ det = pm .Deterministic ("det" , mu + sigma )
1863+ y = pm .Normal ("y" , det , sigma , observed = [1 , 2 , 3 ])
1864+
1865+ # Request only mu, but y depends on det which depends on mu
1866+ prior_samples = pm .sample_prior (draws = 100 , var_names = ["mu" ], return_inferencedata = False )
1867+
1868+ # Should only contain mu, not det or sigma even though they're dependencies
1869+ assert "mu" in prior_samples
1870+ assert "sigma" not in prior_samples
1871+ assert "det" not in prior_samples
1872+ assert "y" not in prior_samples
1873+ assert len (prior_samples ) == 1
1874+
1875+ def test_deterministics_behavior (self , seeded_test ):
1876+ """Test that sample_prior only includes deterministics that don't depend on observed variables."""
1877+ with pm .Model () as model :
1878+ mu = pm .Normal ("mu" , 0 , 1 )
1879+ sigma = pm .HalfNormal ("sigma" , 1 )
1880+ y = pm .Normal ("y" , mu , sigma , observed = [1 , 2 , 3 ])
1881+
1882+ # Deterministic that depends only on unobserved RVs
1883+ det_prior = pm .Deterministic ("det_prior" , mu + sigma )
1884+
1885+ # Deterministic that depends on observed RV
1886+ det_obs = pm .Deterministic ("det_obs" , y + mu )
1887+
1888+ prior_samples = pm .sample_prior (draws = 100 , return_inferencedata = False )
1889+
1890+ # Should include deterministics that depend only on unobserved RVs
1891+ assert "det_prior" in prior_samples
1892+
1893+ # Should NOT include deterministics that depend on observed variables
1894+ assert "det_obs" not in prior_samples
1895+
1896+ # Should not include observed variables
1897+ assert "y" not in prior_samples
1898+
1899+ def test_empty_var_names_behavior (self , seeded_test ):
1900+ """Test what happens when we pass an empty var_names set."""
1901+ with pm .Model () as model :
1902+ mu = pm .Normal ("mu" , 0 , 1 )
1903+ sigma = pm .HalfNormal ("sigma" , 1 )
1904+ y = pm .Normal ("y" , mu , sigma , observed = [1 , 2 , 3 ])
1905+ det = pm .Deterministic ("det" , mu + sigma )
1906+
1907+ # Test with empty var_names
1908+ empty_samples = pm .sample_prior (draws = 100 , var_names = [], return_inferencedata = False )
1909+
1910+ assert empty_samples == {}
0 commit comments