@@ -87,6 +87,122 @@ def can_handle_model(cls, model: Any) -> bool:
8787 """
8888 return isinstance (model , sklearn .base .BaseEstimator )
8989
@classmethod
def trim_flow_name(
        cls,
        long_name: str,
        extra_trim_length: int = 100,
        _outer: bool = True
) -> str:
    """ Shorten a generated sklearn flow name to a compact, human-readable form.

    Flows are assumed to have the following naming structure:
        (model_selection)? (pipeline)? (steps)+
    and will be shortened to:
        sklearn.(selection.)?(pipeline.)?(steps)+
    e.g. (white spaces and newlines added for readability)
        sklearn.pipeline.Pipeline(
            columntransformer=sklearn.compose._column_transformer.ColumnTransformer(
                numeric=sklearn.pipeline.Pipeline(
                    imputer=sklearn.preprocessing.imputation.Imputer,
                    standardscaler=sklearn.preprocessing.data.StandardScaler),
                nominal=sklearn.pipeline.Pipeline(
                    simpleimputer=sklearn.impute.SimpleImputer,
                    onehotencoder=sklearn.preprocessing._encoders.OneHotEncoder)),
            variancethreshold=sklearn.feature_selection.variance_threshold.VarianceThreshold,
            svc=sklearn.svm.classes.SVC)
        ->
        sklearn.Pipeline(ColumnTransformer,VarianceThreshold,SVC)

    Parameters
    ----------
    long_name : str
        The full flow name generated by the scikit-learn extension.
    extra_trim_length: int (default=100)
        If the trimmed name of a pipeline would exceed `extra_trim_length` characters,
        additional trimming is performed: only the last step is kept and the others are
        replaced by `...`. The same limit is applied to a pipeline nested in a model
        selection estimator.
        There is no guarantee the end result will not exceed `extra_trim_length`.
    _outer : bool (default=True)
        For internal use only. Specifies if the function is called recursively.

    Returns
    -------
    str
        The shortened flow name.

    Raises
    ------
    ValueError
        If `long_name` contains 'sklearn.model_selection' without a subsequent
        'estimator=', or if a name starting with 'sklearn.pipeline' has no opening
        parenthesis.
    """
    def remove_all_in_parentheses(string: str) -> str:
        # Every pass strips the innermost groups; repeat until nothing is left to strip.
        removals = 1
        while removals > 0:
            string, removals = re.subn(r"\([^()]*\)", "", string)
        return string

    def split_steps(string: str) -> list:
        # Split on commas, but not on those inside square brackets: bracketed text is an
        # already shortened model selection estimator and must stay in one piece.
        steps = []
        depth = 0
        begin = 0
        for position, char in enumerate(string):
            if char == '[':
                depth += 1
            elif char == ']':
                depth -= 1
            elif char == ',' and depth == 0:
                steps.append(string[begin:position])
                begin = position + 1
        steps.append(string[begin:])
        return steps

    def class_name_of(component: str) -> str:
        # Keep only the class name of the dotted path. Bracketed text is preserved as is,
        # because it may contain the `...` placeholder which must not be split on dots.
        qualified_name, bracket, nested = component.partition('[')
        return qualified_name.split('.')[-1] + bracket + nested

    # Generally, we want to trim all hyperparameters, the exception to that is for model
    # selection, as the `estimator` hyperparameter is very indicative of what is in the flow.
    # So we first trim name of the `estimator` specified in model selection. For reference, in
    # the example below, we want to trim `sklearn.tree.tree.DecisionTreeClassifier`, and
    # keep it in the final trimmed flow name:
    # sklearn.pipeline.Pipeline(Imputer=sklearn.preprocessing.imputation.Imputer,
    # VarianceThreshold=sklearn.feature_selection.variance_threshold.VarianceThreshold,
    # Estimator=sklearn.model_selection._search.RandomizedSearchCV(estimator=
    # sklearn.tree.tree.DecisionTreeClassifier))
    if 'sklearn.model_selection' in long_name:
        start_index = long_name.index('sklearn.model_selection')
        estimator_start = (start_index
                           + long_name[start_index:].index('estimator=')
                           + len('estimator='))

        model_select_boilerplate = long_name[start_index:estimator_start]
        # above is e.g. "sklearn.model_selection._search.RandomizedSearchCV(estimator="
        model_selection_class = model_select_boilerplate.split('(')[0].split('.')[-1]

        # Now we want to also find and parse the `estimator`, for this we find the closing
        # parenthesis to the model selection technique. If it is missing, the remainder of
        # the name is treated as the estimator.
        estimator_end = len(long_name)
        closing_parenthesis_expected = 1
        for i, char in enumerate(long_name[estimator_start:], start=estimator_start):
            if char == '(':
                closing_parenthesis_expected += 1
            elif char == ')':
                closing_parenthesis_expected -= 1
                if closing_parenthesis_expected == 0:
                    estimator_end = i
                    break

        model_select_pipeline = long_name[estimator_start:estimator_end]
        trimmed_pipeline = cls.trim_flow_name(
            model_select_pipeline,
            extra_trim_length=extra_trim_length,
            _outer=False,
        )
        _, trimmed_pipeline = trimmed_pipeline.split('.', maxsplit=1)  # trim module prefix
        model_select_short = "sklearn.{}[{}]".format(model_selection_class, trimmed_pipeline)
        name = long_name[:start_index] + model_select_short + long_name[estimator_end + 1:]
    else:
        name = long_name

    module_prefix = long_name.split('.')[0] + '.'

    if name.startswith('sklearn.pipeline'):
        full_pipeline_class, pipeline = name[:-1].split('(', maxsplit=1)
        pipeline_class = full_pipeline_class.split('.')[-1]
        # We don't want nested pipelines in the short name, so we trim all complicated
        # subcomponents, i.e. those with parentheses:
        pipeline = remove_all_in_parentheses(pipeline)

        # then the pipeline steps are formatted e.g.:
        # step1name=sklearn.submodule.ClassName,step2name...
        components = [class_name_of(component) for component in split_steps(pipeline)]
        pipeline = "{}({})".format(pipeline_class, ','.join(components))
        if len(module_prefix + pipeline) > extra_trim_length:
            pipeline = "{}(...,{})".format(pipeline_class, components[-1])
    else:
        # Just a simple component: e.g. sklearn.tree.DecisionTreeClassifier
        pipeline = class_name_of(remove_all_in_parentheses(name))

    if _outer:
        # Square brackets may be introduced with nested model_selection
        pipeline = pipeline.replace('[', '(').replace(']', ')')
    else:
        # Anything from parenthesis in inner calls should not be culled, so we use brackets
        pipeline = pipeline.replace('(', '[').replace(')', ']')

    return module_prefix + pipeline
205+
90206 ################################################################################################
91207 # Methods for flow serialization and de-serialization
92208
@@ -402,6 +518,7 @@ def _serialize_model(self, model: Any) -> OpenMLFlow:
402518 name = '%s(%s)' % (class_name , sub_components_names [1 :])
403519 else :
404520 name = class_name
521+ short_name = SklearnExtension .trim_flow_name (name )
405522
406523 # Get the external versions of all sub-components
407524 external_version = self ._get_external_version_string (model , subcomponents )
@@ -419,6 +536,7 @@ def _serialize_model(self, model: Any) -> OpenMLFlow:
419536 sklearn_version_formatted = sklearn_version .replace ('==' , '_' )
420537 flow = OpenMLFlow (name = name ,
421538 class_name = class_name ,
539+ custom_name = short_name ,
422540 description = 'Automatically created scikit-learn flow.' ,
423541 model = model ,
424542 components = subcomponents ,
0 commit comments