Skip to content

Commit 90a0687

Browse files
committed
fix(pipeline): correct state update logic during pruning
1 parent a0271ad commit 90a0687

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

rework_pysatl_mpest/estimators/iterative/pipeline.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -165,20 +165,16 @@ def fit(self, X: ArrayLike, mixture: MixtureModel[DType]) -> MixtureModel[DType]
165165

166166
X = np.asarray(X, dtype=mixture.dtype)
167167
copied_mixture = copy(mixture) # Copy to avoid modifying the original object
168-
removed_indices: list[int] = []
169168
state = PipelineState(X, None, None, copied_mixture, None)
170169

171170
while True:
172171
# Updating the state before starting an iteration
173172
state.prev_mixture = copy(state.curr_mixture)
174-
# Update responsibility matrix H
175-
if state.H is not None:
176-
state.H = np.delete(state.H, removed_indices, axis=1)
177173

178174
# Performing steps
179175
for step in self.steps:
180176
result_state = step.run(state)
181-
step.clear_after_prune(removed_indices)
177+
# Log the error state before exiting
182178
if result_state.error:
183179
if len(self.history) > 0:
184180
self.history[-1].error = result_state.error
@@ -219,11 +215,18 @@ def fit(self, X: ArrayLike, mixture: MixtureModel[DType]) -> MixtureModel[DType]
219215
# Pruning
220216
for pruner in self.pruners:
221217
state, removed_components_indices = pruner.prune(state)
222-
if len(removed_components_indices) != 0:
223-
removed_indices.extend(removed_components_indices)
224-
# Save iteration record
225-
self.history.save_record(
226-
IterationRecord(self.history._counter, state.curr_mixture, state.X, state.H, self.pruners, state.error)
218+
if removed_components_indices:
219+
# Update optimization blocks in steps to drop blocks associated with removed components
220+
for step in self.steps:
221+
step.clear_after_prune(removed_components_indices)
222+
223+
# Update responsibility matrix H to match the new mixture size
224+
if state.H is not None:
225+
state.H = np.delete(state.H, removed_components_indices, axis=1)
226+
227+
# Log
228+
self.logger.log(
229+
IterationRecord(self.logger._counter, state.curr_mixture, state.X, state.H, self.pruners, state.error)
227230
)
228231

229232
# Checking stopping criteria

0 commit comments

Comments
 (0)