File tree Expand file tree Collapse file tree 2 files changed +6
-196
lines changed
src/midst_toolkit/models/clavaddpm Expand file tree Collapse file tree 2 files changed +6
-196
lines changed Load Diff This file was deleted.
Original file line number Diff line number Diff line change @@ -505,7 +505,6 @@ def train_classifier(
505505 cluster_col : str = "cluster" ,
506506 dim_t : int = 128 ,
507507 lr : float = 0.0001 ,
508- pre_trained_classifier : Classifier | None = None ,
509508) -> Classifier :
510509 T = Transformations (** T_dict )
511510 dataset , label_encoders , column_orders = make_dataset_from_df (
@@ -533,17 +532,12 @@ def train_classifier(
533532 if model_params ["is_y_cond" ] == "concat" :
534533 num_numerical_features -= 1
535534
536- if pre_trained_classifier is None :
537- classifier = Classifier (
538- d_in = num_numerical_features ,
539- d_out = int (max (df [cluster_col ].values ) + 1 ),
540- dim_t = dim_t ,
541- hidden_sizes = d_layers ,
542- ).to (device )
543- else :
544- classifier = pre_trained_classifier
545- classifier .to (device )
546-
535+ classifier = Classifier (
536+ d_in = num_numerical_features ,
537+ d_out = int (max (df [cluster_col ].values ) + 1 ),
538+ dim_t = dim_t ,
539+ hidden_sizes = d_layers ,
540+ ).to (device )
547541 classifier_optimizer = optim .AdamW (classifier .parameters (), lr = lr )
548542
549543 empty_diffusion = GaussianMultinomialDiffusion (
You can’t perform that action at this time.
0 commit comments