@@ -383,12 +383,10 @@ def fit(
383383 self .feature_names = feature_names
384384
385385 if sample_weight is not None :
386- # Choose appropriate expansion depending on the library type
387- lib = self .feature_library .__class__ .__name__
388- if lib in ("WeakPDELibrary" , "WeightedWeakPDELibrary" ):
389- sample_weight = _expand_weak_sample_weights (sample_weight , x , self .feature_library )
390- else :
391- sample_weight = _expand_sample_weights (sample_weight , x )
386+ mode = "weak" if "Weak" in self .feature_library .__class__ .__name__ else "standard"
387+ sample_weight = _expand_sample_weights (
388+ sample_weight , x , feature_library = self .feature_library , mode = mode
389+ )
392390
393391 steps = [
394392 ("features" , self .feature_library ),
@@ -947,137 +945,78 @@ def comprehend_and_validate(arr, t):
947945 u = [comprehend_and_validate (ui , ti ) for ui , ti in _zip_like_sequence (u , t )]
948946 return x , x_dot , u
949947
950- def _assert_sample_weights (sample_weight , trajectories ):
951- """Validate per-trajectory sample_weight input .
948+ def _expand_sample_weights (sample_weight , trajectories , feature_library = None , mode = "standard" ):
949+ """Expand per-trajectory sample weights for estimators or weak-form libraries .
952950
953- Requirements enforced:
954- - `sample_weight` must be a Python sequence (e.g., list or tuple), not a scalar
955- and not a single numpy array.
956- - Length of the sequence must equal the number of trajectories.
957- - Each element must be array-like with first axis length equal to the
958- corresponding trajectory's `n_time`.
959- - If an element is 2D, its second dimension must be 1 (broadcast) or equal
960- to the trajectory coordinate count `n_coord`.
951+ Parameters
952+ ----------
953+ sample_weight : sequence of array-like or None
954+ Per-trajectory sample weights. Each element corresponds to one trajectory.
955+ trajectories : sequence
956+ Sequence of trajectory objects, each with attributes `n_time` and `n_coord`.
957+ feature_library : object, optional
958+ Library instance, required when mode='weak'.
959+ mode : {'standard', 'weak'}, default='standard'
960+ Expansion mode:
961+ - 'standard' : Concatenate weights per sample or per coordinate.
962+ - 'weak' : Expand weights for weak-form (integral) test functions.
961963
962964 Returns
963965 -------
964- validated_list : list of numpy arrays
965- The per-trajectory arrays (not concatenated) .
966+ np.ndarray or None
967+ Concatenated and expanded sample weights, or None if no weights are given .
966968 """
967-
968969 if sample_weight is None :
969970 return None
970971
971972 if not (isinstance (sample_weight , Sequence ) and not isinstance (sample_weight , np .ndarray )):
972- raise ValueError (
973- "sample_weight must be a sequence (e.g. list or tuple) with one entry per trajectory. "
974- "Do not pass a scalar or a single numpy array here."
975- )
973+ raise ValueError ("sample_weight must be a list or tuple, not a scalar or numpy array." )
976974
977975 if len (sample_weight ) != len (trajectories ):
978- raise ValueError (
979- f"When passing a sequence of sample_weight, its length ({ len (sample_weight )} ) must equal the number of trajectories ({ len (trajectories )} )"
980- )
976+ raise ValueError ("sample_weight length must match number of trajectories." )
981977
978+ # --- Validate shape consistency ---
982979 validated = []
983980 for sw , traj in zip (sample_weight , trajectories ):
984- a = np .asarray (sw )
985- if a .ndim == 0 :
986- validated .append (a )
981+ arr = np .asarray (sw )
982+ if arr .ndim == 0 :
983+ validated .append (arr )
987984 continue
988- if a .shape [0 ] != traj .n_time :
989- raise ValueError (
990- f"sample_weight entry length ({ a .shape [0 ]} ) does not match trajectory length ({ traj .n_time } )"
991- )
992- if a .ndim == 2 and a .shape [1 ] not in (1 , traj .n_coord ):
993- raise ValueError (
994- f"sample_weight entry second dimension ({ a .shape [1 ]} ) must be 1 or equal to the number of coordinates ({ traj .n_coord } )"
995- )
996- validated .append (a )
997- return validated
998-
999-
1000- def _expand_sample_weights (sample_weight , trajectories ):
1001- """Concatenate per-trajectory sample weights into final array for estimators.
1002-
1003- Expects `sample_weight` to be a sequence (validated by _assert_sample_weights).
1004- Returns either:
1005- - 1D array of shape (N_total_samples,) when weights are per-sample scalars, or
1006- - 2D array of shape (N_total_samples, n_coord) when per-sample-per-coordinate weights
1007- are provided (or when 1D weights are promoted to match n_coord).
1008- """
1009-
1010- sw_list = _assert_sample_weights (sample_weight , trajectories )
1011-
1012- if sw_list is None :
1013- return None
1014-
1015- # Determine common coordinate count
985+ if arr .shape [0 ] != traj .n_time :
986+ raise ValueError ("sample_weight entry length does not match trajectory length." )
987+ if arr .ndim == 2 and arr .shape [1 ] not in (1 , traj .n_coord ):
988+ raise ValueError ("sample_weight 2D second dim must be 1 or equal to n_coord." )
989+ validated .append (arr )
990+
991+ # --- Weak-form expansion ---
992+ if mode == "weak" :
993+ n_funcs = getattr (feature_library , "K" , 1 )
994+ if n_funcs is None :
995+ warnings .warn ("feature_library missing 'K'; assuming 1 test function." )
996+ n_funcs = 1
997+ return np .concatenate ([np .repeat (np .asarray (sw ), n_funcs , axis = 0 ) for sw in validated ])
998+
999+ # --- Standard expansion ---
10161000 n_coords = {int (t .n_coord ) for t in trajectories }
10171001 if len (n_coords ) != 1 :
1018- raise ValueError ("All trajectories must have the same number of coordinate components" )
1019- n_coord = next (iter (n_coords ))
1020-
1021- arrs = []
1022- dims = []
1023- for a in sw_list :
1024- a = np .asarray (a )
1025- if a .ndim == 1 :
1026- arrs .append (a )
1027- dims .append (1 )
1028- else : # a.ndim == 2
1029- if a .shape [1 ] == 1 :
1030- arrs .append (a [:, 0 ])
1031- dims .append (1 )
1032- else :
1033- arrs .append (a )
1034- dims .append (n_coord )
1035-
1036- # All 1D -> concatenate to 1D
1037- if all (d == 1 for d in dims ):
1038- return np .concatenate ([a .reshape (- 1 ) for a in arrs ], axis = 0 )
1039-
1040- # Otherwise promote 1D arrays to 2D and concatenate
1041- promoted = []
1042- for a , d in zip (arrs , dims ):
1043- if d == 1 :
1044- promoted .append (np .broadcast_to (a .reshape (- 1 , 1 ), (a .shape [0 ], n_coord )))
1045- else :
1046- promoted .append (a )
1047- return np .concatenate (promoted , axis = 0 )
1048-
1049-
1050- def _expand_weak_sample_weights (sample_weight , trajectories , feature_library ):
1051- """Expand sample weights for weak-form (integral) SINDy libraries.
1052-
1053- Each trajectory contributes multiple weak test functions (integrals).
1054- This expands the sample weights to match the number of weak test functions
1055- per trajectory, and concatenates across all trajectories.
1056-
1057- Returns
1058- -------
1059- np.ndarray
1060- Expanded weights with shape matching the number of weak test function
1061- evaluations across all trajectories.
1062- """
1063- sw_list = _assert_sample_weights (sample_weight , trajectories )
1064- if sw_list is None :
1065- return None
1066-
1067- # Number of test functions in the weak library
1068- n_test_funcs = getattr (feature_library , "K" , None )
1069- if n_test_funcs is None :
1070- warnings .warn (
1071- "Weak-form feature library did not define `n_test_functions`; "
1072- "assuming 1 weight per trajectory."
1073- )
1074- n_test_funcs = 1
1075-
1076- expanded = []
1077- for sw , traj in zip (sw_list , trajectories ):
1078- # Each trajectory contributes n_test_funcs weak equations
1079- sw = np .asarray (sw )
1080- # Expand weights by repeating for each weak test function
1081- sw_expanded = np .repeat (sw , n_test_funcs , axis = 0 )
1082- expanded .append (sw_expanded )
1002+ raise ValueError ("All trajectories must have the same n_coord." )
1003+ n_coord = n_coords .pop ()
1004+
1005+ processed = []
1006+ for arr in validated :
1007+ arr = np .asarray (arr )
1008+ if arr .ndim == 1 :
1009+ arr = arr .reshape (- 1 , 1 )
1010+ elif arr .ndim == 2 and arr .shape [1 ] == 1 :
1011+ pass # already correct shape
1012+ processed .append (arr )
1013+
1014+ # Promote to n_coord if any arrays have multiple coordinates
1015+ is_scalar_weight = all (a .shape [1 ] == 1 for a in processed )
1016+ if is_scalar_weight :
1017+ return np .concatenate ([a .ravel () for a in processed ])
1018+ expanded = [
1019+ np .broadcast_to (a , (a .shape [0 ], n_coord )) if a .shape [1 ] == 1 else a
1020+ for a in processed
1021+ ]
10831022 return np .concatenate (expanded , axis = 0 )
0 commit comments