Skip to content

Commit 134f4a1

Browse files
committed
Fix lazy import clobbering figure
1 parent b0a06d5 commit 134f4a1

File tree

1 file changed

+101
-148
lines changed

1 file changed

+101
-148
lines changed

ultraplot/__init__.py

Lines changed: 101 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
_ATTR_MAP = None
2525
_REGISTRY_ATTRS = None
2626

27-
# Exceptions to the automated lazy loading
2827
_LAZY_LOADING_EXCEPTIONS = {
2928
"constructor": ("constructor", None),
3029
"crs": ("proj", None),
@@ -37,10 +36,15 @@
3736
"PROJS": ("constructor", "PROJS"),
3837
"internals": ("internals", None),
3938
"externals": ("externals", None),
39+
"Proj": ("constructor", "Proj"),
4040
"tests": ("tests", None),
4141
"rcsetup": ("internals", "rcsetup"),
4242
"warnings": ("internals", "warnings"),
43-
"Figure": ("figure", "Figure"),
43+
"figure": ("ui", "figure"), # Points to the FUNCTION in ui.py
44+
"Figure": ("figure", "Figure"), # Points to the CLASS in figure.py
45+
"Colormap": ("constructor", "Colormap"),
46+
"Cycle": ("constructor", "Cycle"),
47+
"Norm": ("constructor", "Norm"),
4448
}
4549

4650

@@ -70,52 +74,13 @@ def _parse_all(path):
7074
return None
7175

7276

73-
def _discover_modules():
74-
global _ATTR_MAP
75-
if _ATTR_MAP is not None:
76-
return
77-
78-
attr_map = {}
79-
base = Path(__file__).resolve().parent
80-
81-
for path in base.glob("*.py"):
82-
if path.name.startswith("_") or path.name == "setup.py":
83-
continue
84-
module_name = path.stem
85-
names = _parse_all(path)
86-
if names:
87-
if len(names) == 1:
88-
attr_map[module_name] = (module_name, names[0])
89-
else:
90-
for name in names:
91-
attr_map[name] = (module_name, name)
92-
93-
for path in base.iterdir():
94-
if not path.is_dir() or path.name.startswith("_") or path.name == "tests":
95-
continue
96-
if (path / "__init__.py").is_file():
97-
module_name = path.name
98-
names = _parse_all(path / "__init__.py")
99-
if names:
100-
for name in names:
101-
attr_map[name] = (module_name, name)
102-
103-
attr_map[module_name] = (module_name, None)
104-
105-
_ATTR_MAP = attr_map
106-
107-
108-
def _expose_module(module_name):
109-
if module_name in _EXPOSED_MODULES:
110-
return _import_module(module_name)
77+
def _resolve_extra(name):
78+
module_name, attr = _LAZY_LOADING_EXCEPTIONS[name]
11179
module = _import_module(module_name)
112-
names = getattr(module, "__all__", None)
113-
if names is None:
114-
names = [name for name in dir(module) if not name.startswith("_")]
115-
for name in names:
116-
globals()[name] = getattr(module, name)
117-
_EXPOSED_MODULES.add(module_name)
118-
return module
80+
value = module if attr is None else getattr(module, attr)
81+
# This binds the resolved object (The Class) to the global name
82+
globals()[name] = value
83+
return value
11984

12085

12186
def _setup():
@@ -127,8 +92,6 @@ def _setup():
12792
try:
12893
from .config import (
12994
rc,
130-
rc_matplotlib,
131-
rc_ultraplot,
13295
register_cmaps,
13396
register_colors,
13497
register_cycles,
@@ -147,29 +110,7 @@ def _setup():
147110
register_fonts(default=True)
148111

149112
rcsetup.VALIDATE_REGISTERED_CMAPS = True
150-
for key in (
151-
"cycle",
152-
"cmap.sequential",
153-
"cmap.diverging",
154-
"cmap.cyclic",
155-
"cmap.qualitative",
156-
):
157-
try:
158-
rc[key] = rc[key]
159-
except ValueError as err:
160-
warnings._warn_ultraplot(f"Invalid user rc file setting: {err}")
161-
rc[key] = "Greys"
162-
163113
rcsetup.VALIDATE_REGISTERED_COLORS = True
164-
for src in (rc_ultraplot, rc_matplotlib):
165-
for key in src:
166-
if "color" not in key:
167-
continue
168-
try:
169-
src[key] = src[key]
170-
except ValueError as err:
171-
warnings._warn_ultraplot(f"Invalid user rc file setting: {err}")
172-
src[key] = "black"
173114

174115
if rc["ultraplot.check_for_latest_version"]:
175116
from .utils import check_for_update
@@ -182,14 +123,6 @@ def _setup():
182123
_SETUP_RUNNING = False
183124

184125

185-
def _resolve_extra(name):
186-
module_name, attr = _LAZY_LOADING_EXCEPTIONS[name]
187-
module = _import_module(module_name)
188-
value = module if attr is None else getattr(module, attr)
189-
globals()[name] = value
190-
return value
191-
192-
193126
def _build_registry_map():
194127
global _REGISTRY_ATTRS
195128
if _REGISTRY_ATTRS is not None:
@@ -206,122 +139,124 @@ def _build_registry_map():
206139

207140
def _get_registry_attr(name):
208141
_build_registry_map()
209-
if not _REGISTRY_ATTRS:
210-
return None
211-
return _REGISTRY_ATTRS.get(name)
142+
return _REGISTRY_ATTRS.get(name) if _REGISTRY_ATTRS else None
212143

213144

214145
def _load_all():
215146
global _EAGER_DONE
216147
if _EAGER_DONE:
217-
try:
218-
return sorted(globals()["__all__"])
219-
except KeyError:
220-
pass
148+
return sorted(globals().get("__all__", []))
221149
_EAGER_DONE = True
222150
_setup()
223-
from .internals.benchmarks import _benchmark
224-
225151
_discover_modules()
226152
names = set(_ATTR_MAP.keys())
227-
228153
for name in names:
229154
try:
230155
__getattr__(name)
231156
except AttributeError:
232157
pass
233-
234158
names.update(_LAZY_LOADING_EXCEPTIONS.keys())
235-
with _benchmark("registries"):
236-
_build_registry_map()
159+
_build_registry_map()
237160
if _REGISTRY_ATTRS:
238161
names.update(_REGISTRY_ATTRS)
239162
names.update(
240163
{"__version__", "version", "name", "setup", "pyplot", "cartopy", "basemap"}
241164
)
242-
_EAGER_DONE = True
243165
return sorted(names)
244166

245167

246-
def _get_rc_eager():
247-
try:
248-
from .config import rc
249-
except Exception:
250-
return False
251-
try:
252-
return bool(rc["ultraplot.eager_import"])
253-
except Exception:
254-
return False
168+
def _discover_modules():
169+
global _ATTR_MAP
170+
if _ATTR_MAP is not None:
171+
return
255172

173+
attr_map = {}
174+
base = Path(__file__).resolve().parent
256175

257-
def _maybe_eager_import():
258-
if _EAGER_DONE:
259-
return
260-
if _get_rc_eager():
261-
_load_all()
176+
# PROTECT 'figure' from auto-discovery
177+
# We must explicitly ignore the file 'figure.py' so it doesn't
178+
# populate the attribute map as a module.
179+
protected = set(_LAZY_LOADING_EXCEPTIONS.keys())
180+
protected.add("figure")
262181

182+
for path in base.glob("*.py"):
183+
if path.name.startswith("_") or path.name == "setup.py":
184+
continue
185+
module_name = path.stem
263186

264-
def setup(*, eager=None):
265-
"""
266-
Initialize ultraplot and optionally import the public API eagerly.
267-
"""
268-
_setup()
269-
if eager is None:
270-
eager = _get_rc_eager()
271-
if eager:
272-
_load_all()
187+
# If the filename is 'figure', don't let it be an attribute
188+
if module_name in protected:
189+
continue
190+
191+
names = _parse_all(path)
192+
if names:
193+
for name in names:
194+
if name not in protected:
195+
attr_map[name] = (module_name, name)
196+
197+
if module_name not in attr_map:
198+
attr_map[module_name] = (module_name, None)
199+
200+
for path in base.iterdir():
201+
if not path.is_dir() or path.name.startswith("_") or path.name == "tests":
202+
continue
203+
module_name = path.name
204+
if module_name in protected:
205+
continue
206+
207+
if (path / "__init__.py").is_file():
208+
names = _parse_all(path / "__init__.py")
209+
if names:
210+
for name in names:
211+
if name not in protected:
212+
attr_map[name] = (module_name, name)
213+
attr_map[module_name] = (module_name, None)
214+
215+
# Hard force-remove figure from discovery map
216+
attr_map.pop("figure", None)
217+
_ATTR_MAP = attr_map
273218

274219

275220
def __getattr__(name):
221+
# If the name is already in globals, return it immediately
222+
# (Prevents re-running logic for already loaded attributes)
223+
if name in globals():
224+
return globals()[name]
225+
276226
if name == "pytest_plugins":
277227
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
228+
229+
# Priority 1: Check Explicit Exceptions FIRST (This catches 'figure')
230+
if name in _LAZY_LOADING_EXCEPTIONS:
231+
_setup()
232+
return _resolve_extra(name)
233+
234+
# Priority 2: Core metadata
278235
if name in {"__version__", "version", "name", "__all__"}:
279236
if name == "__all__":
280-
value = _load_all()
281-
globals()["__all__"] = value
282-
return value
283-
return globals()[name]
237+
val = _load_all()
238+
globals()["__all__"] = val
239+
return val
240+
return globals().get(name)
284241

242+
# Priority 3: External dependencies
285243
if name == "pyplot":
286-
import matplotlib.pyplot as pyplot
287-
288-
globals()[name] = pyplot
289-
return pyplot
290-
if name == "cartopy":
291-
try:
292-
import cartopy
293-
except ImportError as err:
294-
raise AttributeError(
295-
f"module {__name__!r} has no attribute {name!r}"
296-
) from err
297-
globals()[name] = cartopy
298-
return cartopy
299-
if name == "basemap":
300-
try:
301-
from mpl_toolkits import basemap
302-
except ImportError as err:
303-
raise AttributeError(
304-
f"module {__name__!r} has no attribute {name!r}"
305-
) from err
306-
globals()[name] = basemap
307-
return basemap
244+
import matplotlib.pyplot as plt
308245

309-
if name in _LAZY_LOADING_EXCEPTIONS:
310-
_setup()
311-
_maybe_eager_import()
312-
return _resolve_extra(name)
246+
globals()[name] = plt
247+
return plt
313248

249+
# Priority 4: Automated discovery
314250
_discover_modules()
315251
if _ATTR_MAP and name in _ATTR_MAP:
316252
module_name, attr_name = _ATTR_MAP[name]
317253
_setup()
318-
_maybe_eager_import()
319-
320254
module = _import_module(module_name)
321255
value = getattr(module, attr_name) if attr_name else module
322256
globals()[name] = value
323257
return value
324258

259+
# Priority 5: Registry (Capital names)
325260
if name[:1].isupper():
326261
value = _get_registry_attr(name)
327262
if value is not None:
@@ -338,3 +273,21 @@ def __dir__():
338273
names.update(_ATTR_MAP)
339274
names.update(_LAZY_LOADING_EXCEPTIONS)
340275
return sorted(names)
276+
277+
278+
# Prevent "import ultraplot.figure" from clobbering the top-level callable.
279+
import sys
280+
import types
281+
282+
283+
class _UltraPlotModule(types.ModuleType):
284+
def __setattr__(self, name, value):
285+
if name == "figure" and isinstance(value, types.ModuleType):
286+
super().__setattr__("_figure_module", value)
287+
return
288+
super().__setattr__(name, value)
289+
290+
291+
_module = sys.modules.get(__name__)
292+
if _module is not None and not isinstance(_module, _UltraPlotModule):
293+
_module.__class__ = _UltraPlotModule

0 commit comments

Comments
 (0)