Skip to content

Commit ca14837

Browse files
authored
Merge pull request #52 from allenai/fastcluster_robust
Make fastcluster backwards compatible
2 parents 64fc6c1 + 29e0fd1 commit ca14837

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

s2and/model.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import copy
1717
import math
18+
import inspect
1819

1920
import numpy as np
2021
from scipy.cluster.hierarchy import fcluster
@@ -1480,6 +1481,45 @@ def __init__(
14801481
self.input_as_observation_matrix = input_as_observation_matrix
14811482
self.labels_ = None
14821483

1484+
# ---- new: robust get_params ----
1485+
def get_params(self, deep=True):
1486+
"""
1487+
Return params but gracefully handle the case where an instance
1488+
(e.g., loaded from an old pickle) is missing attributes.
1489+
"""
1490+
params = {}
1491+
sig = inspect.signature(self.__class__.__init__)
1492+
for name, param in sig.parameters.items():
1493+
if name == "self":
1494+
continue
1495+
# prefer the runtime attribute if present, otherwise the __init__ default
1496+
if hasattr(self, name):
1497+
params[name] = getattr(self, name)
1498+
else:
1499+
params[name] = param.default if param.default is not inspect._empty else None
1500+
1501+
if deep:
1502+
# sklearn convention: include nested estimator params with __ separator
1503+
for key, val in list(params.items()):
1504+
if hasattr(val, "get_params"):
1505+
for subk, subv in val.get_params(deep=True).items():
1506+
params[f"{key}__{subk}"] = subv
1507+
return params
1508+
1509+
# ---- new: ensure defaults after unpickling ----
1510+
def __setstate__(self, state):
1511+
"""
1512+
Called on unpickle. Populate any missing ctor attrs with their defaults.
1513+
"""
1514+
self.__dict__.update(state)
1515+
sig = inspect.signature(self.__class__.__init__)
1516+
for name, param in sig.parameters.items():
1517+
if name == "self":
1518+
continue
1519+
if not hasattr(self, name):
1520+
default = param.default if param.default is not inspect._empty else None
1521+
setattr(self, name, default)
1522+
14831523
def fit(self, X: np.ndarray) -> np.ndarray:
14841524
"""
14851525
Fit the estimator on input data. The results are stored in self.labels_.

0 commit comments

Comments
 (0)