1010from copy import copy
1111
1212import numpy as np
13-
14- from rework_pysatl_mpest .core import MixtureModel
1513from rework_pysatl_mpest .estimators import ECM
16- from rework_pysatl_mpest .estimators .iterative import ExpectationStep , PipelineState , OptimizationBlock , \
17- MaximizationStrategy , MaximizationStep
14+ from rework_pysatl_mpest .estimators .iterative import (
15+ ExpectationStep ,
16+ MaximizationStep ,
17+ MaximizationStrategy ,
18+ OptimizationBlock ,
19+ PipelineState ,
20+ )
1821from rework_pysatl_mpest .estimators .iterative .breakpointers import StepBreakpointer
1922from rework_pysatl_mpest .optimizers import ScipyNelderMead
20- from .common import Benchmark , DTYPES_MAP , SAMPLE_SIZES , get_components , DISTRIBUTIONS , RNG_GENERATOR
23+
24+ from .common import DISTRIBUTIONS , DTYPES_MAP , RNG_GENERATOR , SAMPLE_SIZES , Benchmark , LibAdapter , get_components
2125
2226
2327class StepOverhead (Benchmark ):
@@ -34,8 +38,8 @@ class StepOverhead(Benchmark):
3438 DISTRIBUTIONS , # dist_name
3539 [2 ], # n_components
3640 SAMPLE_SIZES , # n_samples
37- list (DTYPES_MAP .keys ()), # dtype_name
38- [True , False ] # is_soft
41+ list (DTYPES_MAP .keys ()), # dtype_name
42+ [True , False ], # is_soft
3943 )
4044 param_names = ["dist_name" , "n_components" , "n_samples" , "dtype_name" , "is_soft" ]
4145
@@ -52,17 +56,15 @@ def setup(self, dist_name, n_components, n_samples, dtype_name, is_soft):
5256 comp .fix_param ("shape" )
5357 comp .fix_param ("loc" )
5458
55- self .mix_analytical = MixtureModel ( self .comps_analytical , dtype = dtype )
59+ self .mix_analytical = LibAdapter . create_mixture ( components = self .comps_analytical , dtype = dtype )
5660 self .X_analytical = self .mix_analytical .generate (n_samples )
5761
5862 # --- Pipeline Components ---
5963 self .e_step = ExpectationStep (is_soft = is_soft )
6064
6165 # Setup States
6266 # 1. State ready for E-step
63- self .state_analytical_for_E = PipelineState (
64- self .X_analytical , None , None , copy (self .mix_analytical ), None
65- )
67+ self .state_analytical_for_E = PipelineState (self .X_analytical , None , None , copy (self .mix_analytical ), None )
6668
6769 # 2. State ready for M-step (Pre-calculate H)
6870 self .state_analytical_for_M = self .e_step .run (
@@ -103,11 +105,11 @@ class ECMAnalyticalCleanWithStepBreakpointer(Benchmark):
103105 """
104106
105107 params = (
106- ["Normal" , "Exponential" , "Pareto" , "Weibull" ], # dist_name
107- [2 ], # n_components
108- [5 ], # max_steps
109- SAMPLE_SIZES , # n_samples
110- list (DTYPES_MAP .keys ()) # dtype_name
108+ ["Normal" , "Exponential" , "Pareto" , "Weibull" ], # dist_name
109+ [2 ], # n_components
110+ [5 ], # max_steps
111+ SAMPLE_SIZES , # n_samples
112+ list (DTYPES_MAP .keys ()), # dtype_name
111113 )
112114 param_names = ["dist_name" , "n_components" , "max_steps" , "n_samples" , "dtype_name" ]
113115
@@ -120,24 +122,22 @@ def setup(self, dist_name, n_components, max_steps, n_samples, dtype_name):
120122
121123 dtype = DTYPES_MAP [dtype_name ]
122124 true_comps = get_components (dist_name , dtype , n_components )
123- self .X = MixtureModel ( true_comps ).generate (n_samples )
125+ self .X = LibAdapter . create_mixture ( components = true_comps ).generate (n_samples )
124126
125127 start_comps = copy (true_comps )
126128 for comp in start_comps :
127- new_params = np .asarray (comp .get_params_vector (comp .params ), dtype = dtype ) + np .ones (len (comp .params ), dtype = dtype )
129+ new_params = np .asarray (comp .get_params_vector (comp .params ), dtype = dtype ) + np .ones (
130+ len (comp .params ), dtype = dtype
131+ )
128132 comp .set_params_from_vector (comp .params , new_params )
129133 if dist_name == "Weibull" :
130134 comp .fix_param ("shape" )
131135 comp .fix_param ("loc" )
132136
133- self .start_mixture = MixtureModel ( start_comps , dtype = dtype )
137+ self .start_mixture = LibAdapter . create_mixture ( components = start_comps , dtype = dtype )
134138
135139 # Configure ECM to run for a fixed small number of steps to measure throughput
136- self .ecm = ECM (
137- breakpointers = [StepBreakpointer (max_steps = max_steps )],
138- pruners = [],
139- optimizer = ScipyNelderMead ()
140- )
140+ self .ecm = ECM (breakpointers = [StepBreakpointer (max_steps = max_steps )], pruners = [], optimizer = ScipyNelderMead ())
141141
142142 def time_fit (self , dist_name , n_components , max_steps , n_samples , dtype_name ):
143143 self .ecm .fit (self .X , self .start_mixture )
@@ -152,10 +152,10 @@ class ECMAnalyticalOverflow(Benchmark):
152152 """
153153
154154 params = (
155- ["Normal" , "Exponential" , "Pareto" , "Weibull" ], # dist_name
155+ ["Normal" , "Exponential" , "Pareto" , "Weibull" ], # dist_name
156156 [2 ], # n_components
157- SAMPLE_SIZES , # n_samples
158- ["float16" ] # dtype_name
157+ SAMPLE_SIZES , # n_samples
158+ ["float16" ], # dtype_name
159159 )
160160 param_names = ["dist_name" , "n_components" , "n_samples" , "dtype_name" ]
161161 timeout = 300.0
@@ -177,14 +177,10 @@ def setup(self, dist_name, n_components, n_samples, dtype_name):
177177 comp .fix_param ("shape" )
178178 comp .fix_param ("loc" )
179179
180- self .start_mix = MixtureModel ( start_comps , dtype = dtype )
180+ self .start_mix = LibAdapter . create_mixture ( components = start_comps , dtype = dtype )
181181
182182 # Run only 1 step to trigger the error immediately and measure recovery overhead
183- self .ecm = ECM (
184- breakpointers = [StepBreakpointer (max_steps = 1 )],
185- pruners = [],
186- optimizer = ScipyNelderMead ()
187- )
183+ self .ecm = ECM (breakpointers = [StepBreakpointer (max_steps = 1 )], pruners = [], optimizer = ScipyNelderMead ())
188184
189185 def time_fit_restart (self , dist_name , n_components , n_samples , dtype_name ):
190186 with warnings .catch_warnings ():
0 commit comments