Skip to content

Commit 524a95c

Browse files
committed
Improved searching mechanism for matching regular dists
1 parent 4d11ea1 commit 524a95c

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

pymc/dims/distributions/scalar.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -202,21 +202,20 @@ def create_scalar_dims_docstrings():
202202

203203
import pymc.distributions.continuous as _regular_dists
204204

205-
from pymc.distributions import Continuous
206-
207-
# Get all subclasses of Continuous class
208-
imported_dists = {
209-
name: cls
210-
for name, cls in inspect.getmembers(_regular_dists, inspect.isclass)
211-
if issubclass(cls, Continuous) and cls is not Continuous
212-
}
205+
from pymc.distributions import Distribution
213206

214207
# Get all classes declared in this file
215-
dims_dists = {
216-
name: cls
217-
for name, cls in inspect.getmembers(sys.modules[__name__], inspect.isclass)
218-
if issubclass(cls, DimDistribution) and cls.__module__ == __name__
219-
}
208+
dims_dists = {}
209+
for name, cls in inspect.getmembers(sys.modules[__name__], inspect.isclass):
210+
if issubclass(cls, DimDistribution) and cls.__module__ == __name__:
211+
dims_dists[name] = cls
212+
213+
# Get all subclasses of Distribution class that match the name of the Dims classes
214+
imported_dists = {}
215+
for name in dims_dists:
216+
imported_cls = getattr(_regular_dists, name, None)
217+
if imported_cls is not None and issubclass(imported_cls, Distribution):
218+
imported_dists[name] = imported_cls
220219

221220
# Copy docstring from regular distribution to dims distribution
222221
for dist_class_name in dims_dists:
@@ -225,4 +224,4 @@ def create_scalar_dims_docstrings():
225224
if imported_cls and imported_cls.__doc__ and dims_cls.__doc__ is None:
226225
dims_cls.__doc__ = imported_cls.__doc__.replace("tensor_like", "xtensor_like")
227226

228-
del _regular_dists, Continuous, inspect, sys
227+
del _regular_dists, Distribution, inspect, sys

0 commit comments

Comments
 (0)