@@ -117,6 +117,19 @@ def __init__(self, number_of_clauses, T, s, patch_dim, boost_true_positive_feedb
117117 else :
118118 self .s_range = s
119119
120+ def __getstate__ (self ):
121+ state = self .__dict__ .copy ()
122+ state ['mc_ctm_state' ] = self .get_state ()
123+ del state ['mc_ctm' ]
124+ if 'encoded_X' in state :
125+ del state ['encoded_X' ]
126+ return state
127+
128+ def __setstate__ (self , state ):
129+ self .__dict__ .update (state )
130+ self .mc_ctm = _lib .CreateMultiClassTsetlinMachine (self .number_of_classes , self .number_of_clauses , self .number_of_features , self .number_of_patches , self .number_of_ta_chunks , self .number_of_state_bits , self .T , self .s , self .s_range , self .boost_true_positive_feedback , self .weighted_clauses , self .clause_drop_p , self .literal_drop_p )
131+ self .set_state (state ['mc_ctm_state' ])
132+
120133 def __del__ (self ):
121134 if self .mc_ctm != None :
122135 _lib .mc_tm_destroy (self .mc_ctm )
@@ -237,6 +250,19 @@ def __init__(self, number_of_clauses, T, s, boost_true_positive_feedback=1, numb
237250 else :
238251 self .s_range = s
239252
253+ def __getstate__ (self ):
254+ state = self .__dict__ .copy ()
255+ state ['mc_tm_state' ] = self .get_state ()
256+ del state ['mc_tm' ]
257+ if 'encoded_X' in state :
258+ del state ['encoded_X' ]
259+ return state
260+
261+ def __setstate__ (self , state ):
262+ self .__dict__ .update (state )
263+ self .mc_tm = _lib .CreateMultiClassTsetlinMachine (self .number_of_classes , self .number_of_clauses , self .number_of_features , 1 , self .number_of_ta_chunks , self .number_of_state_bits , self .T , self .s , self .s_range , self .boost_true_positive_feedback , self .weighted_clauses , self .clause_drop_p , self .literal_drop_p )
264+ self .set_state (state ['mc_tm_state' ])
265+
240266 def __del__ (self ):
241267 if self .mc_tm != None :
242268 _lib .mc_tm_destroy (self .mc_tm )
@@ -348,6 +374,19 @@ def __init__(self, number_of_clauses, T, s, boost_true_positive_feedback=1, numb
348374 else :
349375 self .s_range = s
350376
377+ def __getstate__ (self ):
378+ state = self .__dict__ .copy ()
379+ state ['rtm_state' ] = self .get_state ()
380+ del state ['rtm' ]
381+ if 'encoded_X' in state :
382+ del state ['encoded_X' ]
383+ return state
384+
385+ def __setstate__ (self , state ):
386+ self .__dict__ .update (state )
387+ self .rtm = _lib .CreateTsetlinMachine (self .number_of_clauses , self .number_of_features , 1 , self .number_of_ta_chunks , self .number_of_state_bits , self .T , self .s , self .s_range , self .boost_true_positive_feedback , self .weighted_clauses )
388+ self .set_state (state ['rtm_state' ])
389+
351390 def __del__ (self ):
352391 if self .rtm != None :
353392 _lib .tm_destroy (self .rtm )
0 commit comments