|
15 | 15 | import logging |
16 | 16 | import copy |
17 | 17 | import math |
| 18 | +import inspect |
18 | 19 |
|
19 | 20 | import numpy as np |
20 | 21 | from scipy.cluster.hierarchy import fcluster |
@@ -1480,6 +1481,45 @@ def __init__( |
1480 | 1481 | self.input_as_observation_matrix = input_as_observation_matrix |
1481 | 1482 | self.labels_ = None |
1482 | 1483 |
|
| 1484 | + # ---- new: robust get_params ---- |
| 1485 | + def get_params(self, deep=True): |
| 1486 | + """ |
| 1487 | + Return params but gracefully handle the case where an instance |
| 1488 | + (e.g., loaded from an old pickle) is missing attributes. |
| 1489 | + """ |
| 1490 | + params = {} |
| 1491 | + sig = inspect.signature(self.__class__.__init__) |
| 1492 | + for name, param in sig.parameters.items(): |
| 1493 | + if name == "self": |
| 1494 | + continue |
| 1495 | + # prefer the runtime attribute if present, otherwise the __init__ default |
| 1496 | + if hasattr(self, name): |
| 1497 | + params[name] = getattr(self, name) |
| 1498 | + else: |
| 1499 | + params[name] = param.default if param.default is not inspect._empty else None |
| 1500 | + |
| 1501 | + if deep: |
| 1502 | + # sklearn convention: include nested estimator params with __ separator |
| 1503 | + for key, val in list(params.items()): |
| 1504 | + if hasattr(val, "get_params"): |
| 1505 | + for subk, subv in val.get_params(deep=True).items(): |
| 1506 | + params[f"{key}__{subk}"] = subv |
| 1507 | + return params |
| 1508 | + |
| 1509 | + # ---- new: ensure defaults after unpickling ---- |
| 1510 | + def __setstate__(self, state): |
| 1511 | + """ |
| 1512 | + Called on unpickle. Populate any missing ctor attrs with their defaults. |
| 1513 | + """ |
| 1514 | + self.__dict__.update(state) |
| 1515 | + sig = inspect.signature(self.__class__.__init__) |
| 1516 | + for name, param in sig.parameters.items(): |
| 1517 | + if name == "self": |
| 1518 | + continue |
| 1519 | + if not hasattr(self, name): |
| 1520 | + default = param.default if param.default is not inspect._empty else None |
| 1521 | + setattr(self, name, default) |
| 1522 | + |
1483 | 1523 | def fit(self, X: np.ndarray) -> np.ndarray: |
1484 | 1524 | """ |
1485 | 1525 | Fit the estimator on input data. The results are stored in self.labels_. |
|
0 commit comments