22Data-manipulation utilities.
33"""
44import re
5+ import types
6+ import warnings
57from collections import Counter
68from itertools import chain , count
79from typing import Callable , Union , List , Type
@@ -72,6 +74,12 @@ class SharedComputeValue:
7274 def __init__ (self , compute_shared , variable = None ):
7375 self .compute_shared = compute_shared
7476 self .variable = variable
77+ if compute_shared is not None \
78+ and not isinstance (compute_shared , (types .BuiltinFunctionType ,
79+ types .FunctionType )) \
80+ and not redefines_eq_and_hash (compute_shared ):
81+ warnings .warn (f"{ type (compute_shared ).__name__ } should define"
82+ f"__eq__ and __hash__ to be used for compute_shared" )
7583
7684 def __call__ (self , data , shared_data = None ):
7785 """Fallback if common parts are not passed."""
@@ -85,6 +93,14 @@ def compute(self, data, shared_data):
8593 Subclasses need to implement this function."""
8694 raise NotImplementedError
8795
96+ def __eq__ (self , other ):
97+ return type (self ) is type (other ) \
98+ and self .compute_shared == other .compute_shared \
99+ and self .variable == other .variable
100+
101+ def __hash__ (self ):
102+ return hash ((type (self ), self .compute_shared , self .variable ))
103+
88104
89105def vstack (arrays ):
90106 """vstack that supports sparse and dense arrays
@@ -307,3 +323,20 @@ def sanitized_name(name: str) -> str:
307323 if sanitized [0 ].isdigit ():
308324 sanitized = "_" + sanitized
309325 return sanitized
326+
327+
328+ def redefines_eq_and_hash (this ):
329+ """
330+ Check if the passed object (or class) redefines __eq__ and __hash__.
331+
332+ Args:
333+ this: class or object
334+ """
335+ if not isinstance (this , type ):
336+ this = type (this )
337+
338+ # if only __eq__ is defined, __hash__ is set to None
339+ if this .__hash__ is None :
340+ return False
341+
342+ return "__hash__" in this .__dict__ and "__eq__" in this .__dict__
0 commit comments