Skip to content

Commit 2252f1f

Browse files
committed
update subnet
1 parent 1d330e8 commit 2252f1f

File tree

3 files changed

+10
-19
lines changed

3 files changed

+10
-19
lines changed
-2.49 KB
Binary file not shown.

intro_example/plot_results.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,7 @@
5555
c2st_results = {k: None for k in models}
5656
mmd_results = {k: None for k in models}
5757
for m in models:
58-
cross_validation = []
59-
for _ in range(5):
60-
cross_validation.append(classifier_two_sample_test(kinematics_samples[m], approx_ground_truth))
61-
c2st_results[m] = np.mean(cross_validation)
58+
c2st_results[m] = classifier_two_sample_test(kinematics_samples[m], approx_ground_truth)
6259
mmd_results[m] = maximum_mean_discrepancy(kinematics_samples[m], approx_ground_truth)#.detach().cpu().item()
6360

6461
# %%

intro_example/train_models.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,31 @@
1010

1111
from keras.utils import clear_session
1212

13-
logging.getLogger("bayesflow").setLevel(logging.DEBUG)
1413
BASE = Path(__file__).resolve().parent
1514
EPOCHS = 1000
1615
BATCH_SIZE = 128
1716
NUM_SAMPLES_INFERENCE = 1000
1817
MODELS = {
19-
"flow_matching": (bf.networks.FlowMatching, {"subnet": "mlp"}),
20-
"cot_flow_matching": (bf.networks.FlowMatching, {"use_optimal_transport": True, "subnet": "mlp"}),
21-
"consistency_model": (bf.networks.ConsistencyModel, {"total_steps": EPOCHS*BATCH_SIZE, "subnet": "mlp"}),
22-
"stable_consistency_model": (bf.experimental.StableConsistencyModel, {"subnet": "mlp"}),
18+
"flow_matching": (bf.networks.FlowMatching, {}),
19+
"cot_flow_matching": (bf.networks.FlowMatching, {"use_optimal_transport": True}),
20+
"consistency_model": (bf.networks.ConsistencyModel, {"total_steps": EPOCHS*BATCH_SIZE}),
21+
"stable_consistency_model": (bf.experimental.StableConsistencyModel, {}),
2322
"diffusion_edm_vp": (bf.networks.DiffusionModel, {
24-
"subnet": "mlp",
25-
"noise_schedule": "edm",
23+
"noise_schedule": "edm",
2624
"prediction_type": "F",
2725
"schedule_kwargs": {"variance_type": "preserving"}}),
2826
"diffusion_edm_ve": (bf.networks.DiffusionModel, {
29-
"subnet": "mlp",
30-
"noise_schedule": "edm",
27+
"noise_schedule": "edm",
3128
"prediction_type": "F",
3229
"schedule_kwargs": {"variance_type": "exploding"}}),
3330
"diffusion_cosine_F": (bf.networks.DiffusionModel, {
34-
"subnet": "mlp",
35-
"noise_schedule": "cosine",
31+
"noise_schedule": "cosine",
3632
"prediction_type": "F", }),
3733
"diffusion_cosine_v": (bf.networks.DiffusionModel, {
38-
"subnet": "mlp",
39-
"noise_schedule": "cosine",
34+
"noise_schedule": "cosine",
4035
"prediction_type": "velocity"}),
4136
"diffusion_cosine_noise": (bf.networks.DiffusionModel, {
42-
"subnet": "mlp",
43-
"noise_schedule": "cosine",
37+
"noise_schedule": "cosine",
4438
"prediction_type": "noise"}),
4539
}
4640

0 commit comments

Comments
 (0)