Skip to content

Commit 45f4689

Browse files
authored
Adding synthesizer example (#93)
Adding examples for synthesizing single-table and multi-table data. Also, fixing a couple of bugs that appeared while testing the examples.
1 parent 63c336a commit 45f4689

File tree

16 files changed

+408
-17
lines changed

16 files changed

+408
-17
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,7 @@ examples/training/single_table/data/**
4747
examples/training/single_table/results/**
4848
examples/training/multi_table/data/**
4949
examples/training/multi_table/results/**
50+
examples/synthesizing/single_table/data/**
51+
examples/synthesizing/single_table/results/**
52+
examples/synthesizing/multi_table/data/**
53+
examples/synthesizing/multi_table/results/**
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Multi-Table Synthesizing Example
2+
3+
This example will go over synthesizing data for a multi-table dataset from the ground
4+
up using the code in this toolkit.
5+
6+
7+
## Downloading data
8+
9+
First, we need the data. Download it from this
10+
[Google Drive link](https://drive.google.com/file/d/1Ao222l4AJjG54-HDEGCWkIfzRbl9_IKa/view?usp=drive_link),
11+
extract the files and place them in a `/data` folder within this folder
12+
(`examples/synthesizing/multi_table`).
13+
14+
> [!NOTE]
15+
> If you wish to change the data folder, you can do so by editing the `base_data_dir` attribute
16+
> of the [`config.yaml`](config.yaml) file.
17+
18+
It will contain data for 8 tables: `account`, `card`, `client`, `disp`, `district`, `loan`, `order`,
19+
and `trans`. For each table there will be two files:
20+
- `{table_name}.csv`: The table's data.
21+
- `{table_name}_domain.json`: Metadata about the columns in the table's data, such as data types and sizes.
22+
23+
Additionally, you will find one more file:
24+
- `dataset_meta.json`: Metadata about the relationship between the tables. It will describe which tables
25+
are associated with which other tables.
26+
27+
28+
## Kicking off synthesizing
29+
30+
If there is a `/results` folder within this folder (`examples/synthesizing/multi_table`)
31+
from a previous training run, we will use that data to kick off synthesizing.
32+
For example, you can copy the results from another run (e.g. `examples.training.multi_table.run_training`)
33+
and paste them here and it will be picked up by this example.
34+
35+
The [`config.yaml`](config.yaml) file contains the parameters for the synthesizing and also
36+
for training, in case there is a need to run that. Please take a look at them before kicking
37+
off the synthesizing process and edit them as necessary.
38+
39+
To kick off synthesizing, simply run the command below from the project's root folder:
40+
41+
```bash
42+
python -m examples.synthesizing.multi_table.run_synthesizing
43+
```
44+
45+
## Results
46+
47+
It will save the result files inside a `/results` folder within this folder
48+
(`examples/synthesizing/multi_table`).
49+
50+
> [!NOTE]
51+
> If you wish to change the save folder, you can do so by editing the `results_dir` attribute
52+
> of the [`config.yaml`](config.yaml) file.
53+
54+
In the `/results/before_matching/` folder, there will be a file called `synthetic_tables.pkl`,
55+
which is a pickle file containing the synthetic data before the matching process, in case
56+
it's needed.
57+
58+
The `/results/multi_table_synthesizing` folder will contain the final synthesized
59+
data, organized per table, in the form of `.csv` files with the following naming pattern:
60+
`/results/multi_table_synthesizing/{table_name}/_final/{table_name}_synthetic.csv`.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Training example configuration
2+
# Base data directory (can be overridden from command line)
3+
base_data_dir: examples/synthesizing/multi_table/data
4+
results_dir: examples/synthesizing/multi_table/results
5+
6+
# diffusion_config, clustering_config, and classifier_config are only required
7+
# when training a new model from scratch
8+
diffusion_config:
9+
d_layers: [512, 1024, 1024, 1024, 1024, 512]
10+
dropout: 0.0
11+
num_timesteps: 2000
12+
model_type: mlp
13+
iterations: 20000
14+
batch_size: 4096
15+
lr: 0.0006
16+
gaussian_loss_type: mse
17+
weight_decay: 1e-05
18+
scheduler: cosine
19+
data_split_ratios: [0.99, 0.005, 0.005]
20+
21+
clustering_config:
22+
parent_scale: 1.0
23+
num_clusters: 50
24+
clustering_method: kmeans_and_gmm
25+
26+
classifier_config:
27+
d_layers: [128, 256, 512, 1024, 512, 256, 128]
28+
lr: 0.0001
29+
dim_t: 128
30+
batch_size: 4096
31+
iterations: 20000
32+
33+
# Synthesizing configuration
34+
general_config:
35+
data_dir: examples/synthesizing/multi_table/data
36+
test_data_dir: examples/synthesizing/multi_table/data
37+
exp_name: multi_table_synthesizing
38+
workspace_dir: examples/synthesizing/multi_table/results
39+
sample_prefix: ""
40+
41+
sampling_config:
42+
batch_size: 20000
43+
classifier_scale: 1.0
44+
45+
matching_config:
46+
num_matching_clusters: 1
47+
matching_batch_size: 1000
48+
unique_matching: True
49+
no_matching: False
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import pickle
2+
from logging import INFO
3+
from pathlib import Path
4+
from typing import Any
5+
6+
import hydra
7+
from omegaconf import DictConfig
8+
9+
from examples.training.multi_table import run_training
10+
from midst_toolkit.common.config import GeneralConfig, MatchingConfig, SamplingConfig
11+
from midst_toolkit.common.logger import TOOLKIT_LOGGER, log
12+
from midst_toolkit.models.clavaddpm.data_loaders import load_tables
13+
from midst_toolkit.models.clavaddpm.enumerations import Relation
14+
from midst_toolkit.models.clavaddpm.synthesizer import clava_synthesizing
15+
16+
17+
# Preventing some excessive logging
18+
TOOLKIT_LOGGER.setLevel(INFO)
19+
20+
21+
@hydra.main(config_path=".", config_name="config", version_base=None)
22+
def main(config: DictConfig) -> None:
23+
"""
24+
Run the synthesizing pipeline for a multi-table diffusion model.
25+
26+
It will load the config and then data from the `config.base_data_dir` folder,
27+
train the model, synthesize the data and save the results in the
28+
`config.results_dir` folder.
29+
30+
It will first look for a pre-trained model in the `config.results_dir` folder.
31+
If it doesn't find one, it will train a new model from scratch.
32+
33+
Args:
34+
config: Training and synthesizing configuration as an OmegaConf DictConfig object.
35+
"""
36+
log(INFO, f"Checking for a pre-trained model in {config.results_dir}...")
37+
38+
_, relation_order, _ = load_tables(Path(config.base_data_dir))
39+
40+
model_file_paths: dict[Relation, dict[str, Any]] = {}
41+
for relation in relation_order:
42+
model_file_path = Path(config.results_dir) / "models" / f"{relation[0]}_{relation[1]}_ckpt.pkl"
43+
model_file_paths[relation] = {
44+
"file_path": model_file_path,
45+
"exists": model_file_path.exists(),
46+
}
47+
48+
clustering_results_file = Path(config.results_dir) / "cluster_ckpt.pkl"
49+
50+
if all(result["exists"] for result in model_file_paths.values()) and clustering_results_file.exists():
51+
log(INFO, f"Found previous results in {config.results_dir}. Skipping training.")
52+
else:
53+
log(INFO, "Not all previous results found. Training a new model from scratch.")
54+
log(INFO, f"Summary of results: {model_file_paths}")
55+
log(INFO, f"Clustering results file: {clustering_results_file} exists? {clustering_results_file.exists()}")
56+
run_training.main(config)
57+
58+
log(INFO, "Loading models...")
59+
60+
models = {}
61+
for relation in relation_order:
62+
with open(model_file_paths[relation]["file_path"], "rb") as f:
63+
models[relation] = pickle.load(f)
64+
65+
with open(clustering_results_file, "rb") as f:
66+
clustering_result = pickle.load(f)
67+
68+
tables = clustering_result["tables"]
69+
all_group_lengths_prob_dicts = clustering_result["all_group_lengths_prob_dicts"]
70+
71+
log(INFO, "Synthesizing data...")
72+
73+
clava_synthesizing(
74+
tables,
75+
relation_order,
76+
Path(config.results_dir),
77+
models,
78+
GeneralConfig(**config.general_config),
79+
SamplingConfig(**config.sampling_config),
80+
MatchingConfig(**config.matching_config),
81+
all_group_lengths_prob_dicts,
82+
)
83+
84+
log(INFO, "Data synthesized successfully.")
85+
86+
87+
if __name__ == "__main__":
88+
main()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Single-Table Synthesizing Example
2+
3+
This example will go over synthesizing data for a single-table dataset from the ground
4+
up using the code in this toolkit.
5+
6+
7+
## Downloading data
8+
9+
First, we need the data. Download it from this
10+
[Google Drive link](https://drive.google.com/file/d/1J5qDuMHHg4dm9c3ISmb41tcTHSu1SVUC/view?usp=drive_link),
11+
extract the files and place them in a `/data` folder within this folder
12+
(`examples/synthesizing/single_table`).
13+
14+
> [!NOTE]
15+
> If you wish to change the data folder, you can do so by editing the `base_data_dir` attribute
16+
> of the [`config.yaml`](config.yaml) file.
17+
18+
Here is a description of the files that have been extracted:
19+
- `trans.csv`: The training data. It consists of information about bank transactions and it
20+
contains 20,000 data points.
21+
- `trans_domain.json`: Metadata about the columns in `trans.csv`, such as data types and sizes.
22+
- `dataset_meta.json`: Metadata about the relationship between the tables. Since this is a
23+
single-table example, it will only contain information about the `trans` table.
24+
25+
26+
## Kicking off synthesizing
27+
28+
If there is a `/results` folder within this folder (`examples/synthesizing/single_table`)
29+
from a previous training run, we will use that data to kick off synthesizing.
30+
For example, you can copy the results from another run (e.g. `examples.training.single_table.run_training`)
31+
and paste them here and it will be picked up by this example.
32+
33+
The [`config.yaml`](config.yaml) file contains the parameters for the synthesizing and also
34+
for training, in case there is a need to run that. Please take a look at them before kicking
35+
off the synthesizing process and edit them as necessary.
36+
37+
To kick off synthesizing, simply run the command below from the project's root folder:
38+
39+
```bash
40+
python -m examples.synthesizing.single_table.run_synthesizing
41+
```
42+
43+
## Results
44+
45+
It will save the result files inside a `/results` folder within this folder
46+
(`examples/synthesizing/single_table`).
47+
48+
> [!NOTE]
49+
> If you wish to change the save folder, you can do so by editing the `results_dir` attribute
50+
> of the [`config.yaml`](config.yaml) file.
51+
52+
In the `/results/before_matching/` folder, there will be a file called `synthetic_tables.pkl`,
53+
which is a pickle file containing the synthetic data before the matching process, in case
54+
it's needed.
55+
56+
The `/results/single_table_synthesizing` folder will contain the final synthesized
57+
data, organized per table. In this single-table example, there is only going to be one
58+
synthesized table under `/results/single_table_synthesizing/trans/_final/trans_synthetic.csv`.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Training example configuration
2+
# Base data directory (can be overridden from command line)
3+
base_data_dir: examples/synthesizing/single_table/data
4+
results_dir: examples/synthesizing/single_table/results
5+
6+
# diffusion_config is only required when training a new model from scratch
7+
diffusion_config:
8+
d_layers: [512, 1024, 1024, 1024, 1024, 512]
9+
dropout: 0.0
10+
num_timesteps: 2000
11+
model_type: mlp
12+
iterations: 20000
13+
batch_size: 4096
14+
lr: 0.0006
15+
gaussian_loss_type: mse
16+
weight_decay: 1e-05
17+
scheduler: cosine
18+
data_split_ratios: [0.99, 0.005, 0.005]
19+
20+
# Synthesizing configuration
21+
general_config:
22+
data_dir: examples/synthesizing/single_table/data
23+
test_data_dir: examples/synthesizing/single_table/data
24+
exp_name: single_table_synthesizing
25+
workspace_dir: examples/synthesizing/single_table/results
26+
sample_prefix: ""
27+
28+
sampling_config:
29+
batch_size: 20000
30+
classifier_scale: 1.0
31+
32+
matching_config:
33+
num_matching_clusters: 1
34+
matching_batch_size: 1000
35+
unique_matching: True
36+
no_matching: False
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import pickle
2+
from logging import INFO
3+
from pathlib import Path
4+
from typing import Any
5+
6+
import hydra
7+
from omegaconf import DictConfig
8+
9+
from examples.training.single_table import run_training
10+
from midst_toolkit.common.config import GeneralConfig, MatchingConfig, SamplingConfig
11+
from midst_toolkit.common.logger import TOOLKIT_LOGGER, log
12+
from midst_toolkit.models.clavaddpm.data_loaders import load_tables
13+
from midst_toolkit.models.clavaddpm.enumerations import Relation
14+
from midst_toolkit.models.clavaddpm.synthesizer import clava_synthesizing
15+
16+
17+
# Preventing some excessive logging
18+
TOOLKIT_LOGGER.setLevel(INFO)
19+
20+
21+
@hydra.main(config_path=".", config_name="config", version_base=None)
22+
def main(config: DictConfig) -> None:
23+
"""
24+
Run the synthesizing pipeline for a single-table diffusion model.
25+
26+
It will load the config and then data from the `config.base_data_dir` folder,
27+
train the model, synthesize the data and save the results in the
28+
`config.results_dir` folder.
29+
30+
It will first look for a pre-trained model in the `config.results_dir` folder.
31+
If it doesn't find one, it will train a new model from scratch.
32+
33+
Args:
34+
config: Training and synthesizing configuration as an OmegaConf DictConfig object.
35+
"""
36+
log(INFO, f"Checking for a pre-trained model in {config.results_dir}...")
37+
38+
tables, relation_order, _ = load_tables(Path(config.base_data_dir))
39+
40+
assert len(relation_order) == 1 and relation_order[0][0] is None, (
41+
"Relation order is not configured for single-table. "
42+
"For multi-table synthesizing, please use the `examples.synthesizing.multi_table.run_synthesizing` example. "
43+
f"Relation order: {relation_order}"
44+
)
45+
46+
model_file_paths: dict[Relation, dict[str, Any]] = {}
47+
for relation in relation_order:
48+
model_file_path = Path(config.results_dir) / "models" / f"{relation[0]}_{relation[1]}_ckpt.pkl"
49+
model_file_paths[relation] = {
50+
"file_path": model_file_path,
51+
"exists": model_file_path.exists(),
52+
}
53+
54+
if all(result["exists"] for result in model_file_paths.values()):
55+
log(INFO, f"Found previous results in {config.results_dir}. Skipping training.")
56+
else:
57+
log(INFO, "Not all previous results found. Training a new model from scratch.")
58+
log(INFO, f"Summary of results: {model_file_paths}")
59+
run_training.main(config)
60+
61+
log(INFO, "Loading models...")
62+
63+
models = {}
64+
for relation in relation_order:
65+
with open(model_file_paths[relation]["file_path"], "rb") as f:
66+
models[relation] = pickle.load(f)
67+
68+
log(INFO, "Synthesizing data...")
69+
70+
clava_synthesizing(
71+
tables,
72+
relation_order,
73+
Path(config.results_dir),
74+
models,
75+
GeneralConfig(**config.general_config),
76+
SamplingConfig(**config.sampling_config),
77+
MatchingConfig(**config.matching_config),
78+
)
79+
80+
log(INFO, "Data synthesized successfully.")
81+
82+
83+
if __name__ == "__main__":
84+
main()

0 commit comments

Comments
 (0)