|
10 | 10 |
|
11 | 11 | from keras.utils import clear_session |
12 | 12 |
|
13 | | -logging.getLogger("bayesflow").setLevel(logging.DEBUG) |
14 | 13 | BASE = Path(__file__).resolve().parent |
15 | 14 | EPOCHS = 1000 |
16 | 15 | BATCH_SIZE = 128 |
17 | 16 | NUM_SAMPLES_INFERENCE = 1000 |
18 | 17 | MODELS = { |
19 | | - "flow_matching": (bf.networks.FlowMatching, {"subnet": "mlp"}), |
20 | | - "cot_flow_matching": (bf.networks.FlowMatching, {"use_optimal_transport": True, "subnet": "mlp"}), |
21 | | - "consistency_model": (bf.networks.ConsistencyModel, {"total_steps": EPOCHS*BATCH_SIZE, "subnet": "mlp"}), |
22 | | - "stable_consistency_model": (bf.experimental.StableConsistencyModel, {"subnet": "mlp"}), |
| 18 | + "flow_matching": (bf.networks.FlowMatching, {}), |
| 19 | + "cot_flow_matching": (bf.networks.FlowMatching, {"use_optimal_transport": True}), |
| 20 | + "consistency_model": (bf.networks.ConsistencyModel, {"total_steps": EPOCHS*BATCH_SIZE}), |
| 21 | + "stable_consistency_model": (bf.experimental.StableConsistencyModel, {}), |
23 | 22 | "diffusion_edm_vp": (bf.networks.DiffusionModel, { |
24 | | - "subnet": "mlp", |
25 | | - "noise_schedule": "edm", |
| 23 | + "noise_schedule": "edm", |
26 | 24 | "prediction_type": "F", |
27 | 25 | "schedule_kwargs": {"variance_type": "preserving"}}), |
28 | 26 | "diffusion_edm_ve": (bf.networks.DiffusionModel, { |
29 | | - "subnet": "mlp", |
30 | | - "noise_schedule": "edm", |
| 27 | + "noise_schedule": "edm", |
31 | 28 | "prediction_type": "F", |
32 | 29 | "schedule_kwargs": {"variance_type": "exploding"}}), |
33 | 30 | "diffusion_cosine_F": (bf.networks.DiffusionModel, { |
34 | | - "subnet": "mlp", |
35 | | - "noise_schedule": "cosine", |
| 31 | + "noise_schedule": "cosine", |
36 | 32 | "prediction_type": "F", }), |
37 | 33 | "diffusion_cosine_v": (bf.networks.DiffusionModel, { |
38 | | - "subnet": "mlp", |
39 | | - "noise_schedule": "cosine", |
| 34 | + "noise_schedule": "cosine", |
40 | 35 | "prediction_type": "velocity"}), |
41 | 36 | "diffusion_cosine_noise": (bf.networks.DiffusionModel, { |
42 | | - "subnet": "mlp", |
43 | | - "noise_schedule": "cosine", |
| 37 | + "noise_schedule": "cosine", |
44 | 38 | "prediction_type": "noise"}), |
45 | 39 | } |
46 | 40 |
|
|
0 commit comments