Skip to content

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Oct 11, 2023

This bug showed up in pymc-devs/pymc#6947

import pytensor.tensor as pt

p = pt.ones(3) / 3
x = pt.random.categorical(p=pt.stack([p, 1-p], axis=-1))
assert x.type.shape == (3,)  # AssertionError

Which would later lead to a rewrite error during compilation.

It was caused by the presence of np.int in the static shape of Join, (and corresponding np.bool in the broadcastable), which would then be overlooked by an explicit check broadcastable is False in the RandomVariable.infer_shape.

@codecov-commenter
Copy link

Codecov Report

Merging #475 (f236f96) into main (36df379) will increase coverage by 0.00%.
Report is 1 commits behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #475   +/-   ##
=======================================
  Coverage   80.66%   80.66%           
=======================================
  Files         160      160           
  Lines       46025    46029    +4     
  Branches    11266    11268    +2     
=======================================
+ Hits        37124    37128    +4     
  Misses       6668     6668           
  Partials     2233     2233           
Files Coverage Δ
pytensor/tensor/random/op.py 96.25% <ø> (ø)
pytensor/tensor/type.py 94.52% <100.00%> (+0.04%) ⬆️

@michaelosthege michaelosthege merged commit 6834740 into pymc-devs:main Oct 11, 2023
@ricardoV94 ricardoV94 changed the title Fix static type shape bug Fix RandomVariable static type shape bug Oct 12, 2023
@ricardoV94 ricardoV94 deleted the fix_static_type_shape_bug branch October 12, 2023 07:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working shape inference

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants