Skip to content

Commit a4a88da

Browse files
committed
test fixes
1 parent 90eeefc commit a4a88da

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

tests/modules/test_training_feature_importances.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,27 @@
9191

9292

9393
def _get_available_models_by_type():
94-
"""Get all available models dynamically from the registry, grouped by ML type."""
95-
return {ml_type: list(Models.get_models_for_type(ml_type)) for ml_type in MLType}
94+
"""Get all available models dynamically from the registry, grouped by ML type.
95+
96+
Gracefully skips models whose factory functions fail to import
97+
(e.g. TabularNN models that require ``torch``, which may not be
98+
available on all platforms such as Windows CI).
99+
"""
100+
all_models = Models._config_factories.keys()
101+
models_by_type: dict[MLType, list[str]] = {ml_type: [] for ml_type in MLType}
102+
103+
for model_name in all_models:
104+
try:
105+
model_config = Models.get_config(model_name)
106+
for ml_type in MLType:
107+
if model_config.supports_ml_type(ml_type):
108+
models_by_type[ml_type].append(model_name)
109+
except Exception:
110+
# Skip models whose dependencies can't be loaded
111+
# (e.g. torch DLL failure on Windows)
112+
continue
113+
114+
return models_by_type
96115

97116

98117
def _generate_model_fi_params():

0 commit comments

Comments
 (0)