@@ -202,21 +202,20 @@ def create_scalar_dims_docstrings():
202202
203203 import pymc .distributions .continuous as _regular_dists
204204
205- from pymc .distributions import Continuous
206-
207- # Get all subclasses of Continuous class
208- imported_dists = {
209- name : cls
210- for name , cls in inspect .getmembers (_regular_dists , inspect .isclass )
211- if issubclass (cls , Continuous ) and cls is not Continuous
212- }
205+ from pymc .distributions import Distribution
213206
214207 # Get all classes declared in this file
215- dims_dists = {
216- name : cls
217- for name , cls in inspect .getmembers (sys .modules [__name__ ], inspect .isclass )
218- if issubclass (cls , DimDistribution ) and cls .__module__ == __name__
219- }
208+ dims_dists = {}
209+ for name , cls in inspect .getmembers (sys .modules [__name__ ], inspect .isclass ):
210+ if issubclass (cls , DimDistribution ) and cls .__module__ == __name__ :
211+ dims_dists [name ] = cls
212+
213+ # Get all subclasses of Distribution class that match the name of the Dims classes
214+ imported_dists = {}
215+ for name in dims_dists :
216+ imported_cls = getattr (_regular_dists , name , None )
217+ if imported_cls is not None and issubclass (imported_cls , Distribution ):
218+ imported_dists [name ] = imported_cls
220219
221220 # Copy docstring from regular distribution to dims distribution
222221 for dist_class_name in dims_dists :
@@ -225,4 +224,4 @@ def create_scalar_dims_docstrings():
225224 if imported_cls and imported_cls .__doc__ and dims_cls .__doc__ is None :
226225 dims_cls .__doc__ = imported_cls .__doc__ .replace ("tensor_like" , "xtensor_like" )
227226
228- del _regular_dists , Continuous , inspect , sys
227+ del _regular_dists , Distribution , inspect , sys
0 commit comments