You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API
## What changes were proposed in this pull request?
Add python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting.
## How was this patch tested?
UT added.
Author: WeichenXu <[email protected]>
Closesapache#19627 from WeichenXu123/expose-model-list-py.
Copy file name to clipboardExpand all lines: python/pyspark/ml/param/shared.py
+24Lines changed: 24 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -655,6 +655,30 @@ def getParallelism(self):
655
655
returnself.getOrDefault(self.parallelism)
656
656
657
657
658
+
classHasCollectSubModels(Params):
659
+
"""
660
+
Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.
661
+
"""
662
+
663
+
collectSubModels=Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean)
664
+
665
+
def__init__(self):
666
+
super(HasCollectSubModels, self).__init__()
667
+
self._setDefault(collectSubModels=False)
668
+
669
+
defsetCollectSubModels(self, value):
670
+
"""
671
+
Sets the value of :py:attr:`collectSubModels`.
672
+
"""
673
+
returnself._set(collectSubModels=value)
674
+
675
+
defgetCollectSubModels(self):
676
+
"""
677
+
Gets the value of collectSubModels or its default value.
678
+
"""
679
+
returnself.getOrDefault(self.collectSubModels)
680
+
681
+
658
682
classHasLoss(Params):
659
683
"""
660
684
Mixin for param loss: the loss function to be optimized.
0 commit comments