@@ -367,43 +367,58 @@ def test_poisson(self):
367367
368368 @pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
369369 def test_categorical (self , n ):
370+ domain = Domain (range (n ), dtype = "int64" , edges = (0 , n ))
371+ paramdomains = {"p" : Simplex (n )}
372+
370373 check_logp (
371374 pm .Categorical ,
372- Domain ( range ( n ), dtype = "int64" , edges = ( 0 , n )) ,
373- { "p" : Simplex ( n )} ,
375+ domain ,
376+ paramdomains ,
374377 lambda value , p : categorical_logpdf (value , p ),
375378 )
376379
377- def test_categorical_logp_batch_dims (self ):
380+ check_selfconsistency_discrete_logcdf (
381+ pm .Categorical ,
382+ domain ,
383+ paramdomains ,
384+ )
385+
386+ @pytest .mark .parametrize ("method" , (logp , logcdf ), ids = lambda x : x .__name__ )
387+ def test_categorical_logp_batch_dims (self , method ):
378388 # Core case
379389 p = np .array ([0.2 , 0.3 , 0.5 ])
380390 value = np .array (2.0 )
381- logp_expr = logp (pm .Categorical .dist (p = p , shape = value .shape ), value )
382- assert logp_expr .type .ndim == 0
383- np .testing .assert_allclose (logp_expr .eval (), np .log (0.5 ))
391+ expr = method (pm .Categorical .dist (p = p , shape = value .shape ), value )
392+ assert expr .type .ndim == 0
393+ expected_p = 0.5 if method is logp else 1.0
394+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
384395
385396 # Explicit batched value broadcasts p
386397 bcast_p = p [None ] # shape (1, 3)
387398 batch_value = np .array ([0 , 1 ]) # shape(3,)
388- logp_expr = logp (pm .Categorical .dist (p = bcast_p , shape = batch_value .shape ), batch_value )
389- assert logp_expr .type .ndim == 1
390- np .testing .assert_allclose (logp_expr .eval (), np .log ([0.2 , 0.3 ]))
399+ expr = method (pm .Categorical .dist (p = bcast_p , shape = batch_value .shape ), batch_value )
400+ assert expr .type .ndim == 1
401+ expected_p = [0.2 , 0.3 ] if method is logp else [0.2 , 0.5 ]
402+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
403+
404+ # Implicit batch value broadcasts p
405+ expr = method (pm .Categorical .dist (p = p , shape = ()), batch_value )
406+ assert expr .type .ndim == 1
407+ expected_p = [0.2 , 0.3 ] if method is logp else [0.2 , 0.5 ]
408+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
391409
392410 # Explicit batched value and batched p
393411 batch_p = np .array ([p [::- 1 ], p ])
394- logp_expr = logp (pm .Categorical .dist (p = batch_p , shape = batch_value .shape ), batch_value )
395- assert logp_expr .type .ndim == 1
396- np .testing .assert_allclose (logp_expr .eval (), np .log ([0.5 , 0.3 ]))
397-
398- # Implicit batch value broadcasts p
399- logp_expr = logp (pm .Categorical .dist (p = p , shape = ()), batch_value )
400- assert logp_expr .type .ndim == 1
401- np .testing .assert_allclose (logp_expr .eval (), np .log ([0.2 , 0.3 ]))
412+ expr = method (pm .Categorical .dist (p = batch_p , shape = batch_value .shape ), batch_value )
413+ assert expr .type .ndim == 1
414+ expected_p = [0.5 , 0.3 ] if method is logp else [0.5 , 0.5 ]
415+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
402416
403417 # Implicit batch p broadcasts value
404- logp_expr = logp (pm .Categorical .dist (p = batch_p , shape = None ), value )
405- assert logp_expr .type .ndim == 1
406- np .testing .assert_allclose (logp_expr .eval (), np .log ([0.2 , 0.5 ]))
418+ expr = method (pm .Categorical .dist (p = batch_p , shape = None ), value )
419+ assert expr .type .ndim == 1
420+ expected_p = [0.2 , 0.5 ] if method is logp else [1.0 , 1.0 ]
421+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
407422
408423 @pytensor .config .change_flags (compute_test_value = "raise" )
409424 def test_categorical_bounds (self ):
0 commit comments