Skip to content

Commit a352d36

Browse files
committed
Removed some parts that should go with the next PR
1 parent a812665 commit a352d36

File tree

2 files changed

+6
-196
lines changed

2 files changed

+6
-196
lines changed

src/midst_toolkit/models/clavaddpm/fine_tuning_module.py

Lines changed: 0 additions & 184 deletions
This file was deleted.

src/midst_toolkit/models/clavaddpm/model.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,6 @@ def train_classifier(
505505
cluster_col: str = "cluster",
506506
dim_t: int = 128,
507507
lr: float = 0.0001,
508-
pre_trained_classifier: Classifier | None = None,
509508
) -> Classifier:
510509
T = Transformations(**T_dict)
511510
dataset, label_encoders, column_orders = make_dataset_from_df(
@@ -533,17 +532,12 @@ def train_classifier(
533532
if model_params["is_y_cond"] == "concat":
534533
num_numerical_features -= 1
535534

536-
if pre_trained_classifier is None:
537-
classifier = Classifier(
538-
d_in=num_numerical_features,
539-
d_out=int(max(df[cluster_col].values) + 1),
540-
dim_t=dim_t,
541-
hidden_sizes=d_layers,
542-
).to(device)
543-
else:
544-
classifier = pre_trained_classifier
545-
classifier.to(device)
546-
535+
classifier = Classifier(
536+
d_in=num_numerical_features,
537+
d_out=int(max(df[cluster_col].values) + 1),
538+
dim_t=dim_t,
539+
hidden_sizes=d_layers,
540+
).to(device)
547541
classifier_optimizer = optim.AdamW(classifier.parameters(), lr=lr)
548542

549543
empty_diffusion = GaussianMultinomialDiffusion(

0 commit comments

Comments
 (0)