@@ -165,20 +165,16 @@ def fit(self, X: ArrayLike, mixture: MixtureModel[DType]) -> MixtureModel[DType]
165165
166166 X = np .asarray (X , dtype = mixture .dtype )
167167 copied_mixture = copy (mixture ) # Copy to avoid modifying the original object
168- removed_indices : list [int ] = []
169168 state = PipelineState (X , None , None , copied_mixture , None )
170169
171170 while True :
172171 # Updating the state before starting an iteration
173172 state .prev_mixture = copy (state .curr_mixture )
174- # Update responsibility matrix H
175- if state .H is not None :
176- state .H = np .delete (state .H , removed_indices , axis = 1 )
177173
178174 # Performing steps
179175 for step in self .steps :
180176 result_state = step .run (state )
181- step . clear_after_prune ( removed_indices )
177+ # Log the error state before exiting
182178 if result_state .error :
183179 if len (self .history ) > 0 :
184180 self .history [- 1 ].error = result_state .error
@@ -219,11 +215,18 @@ def fit(self, X: ArrayLike, mixture: MixtureModel[DType]) -> MixtureModel[DType]
219215 # Pruning
220216 for pruner in self .pruners :
221217 state , removed_components_indices = pruner .prune (state )
222- if len (removed_components_indices ) != 0 :
223- removed_indices .extend (removed_components_indices )
224- # Save iteration record
225- self .history .save_record (
226- IterationRecord (self .history ._counter , state .curr_mixture , state .X , state .H , self .pruners , state .error )
218+ if removed_components_indices :
219+ # Update optimization blocks in steps to drop blocks associated with removed components
220+ for step in self .steps :
221+ step .clear_after_prune (removed_components_indices )
222+
223+ # Update responsibility matrix H to match the new mixture size
224+ if state .H is not None :
225+ state .H = np .delete (state .H , removed_components_indices , axis = 1 )
226+
227+ # Log
228+ self .logger .log (
229+ IterationRecord (self .logger ._counter , state .curr_mixture , state .X , state .H , self .pruners , state .error )
227230 )
228231
229232 # Checking stopping criteria
0 commit comments