Commit 2502a7e

Merge branch 'main' into bz/tf_example
2 parents bd3da0b + 54fc5a2 commit 2502a7e

17 files changed, with 1925 additions and 1355 deletions.

.github/workflows/code_checks.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ jobs:
3232
- uses: actions/[email protected]
3333

3434
- name: Install uv
35-
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a
35+
uses: astral-sh/setup-uv@681c641aba71e4a1c380be3ab5e12ad51f415867
3636
with:
3737
# Install a specific version of uv.
3838
version: "0.5.21"

.github/workflows/docs.yml

Lines changed: 4 additions & 4 deletions
@@ -45,7 +45,7 @@ jobs:
4545
uses: actions/[email protected]
4646

4747
- name: Install uv
48-
uses: astral-sh/[email protected].5
48+
uses: astral-sh/[email protected].6
4949
with:
5050
version: "0.5.21"
5151
enable-cache: true
@@ -65,7 +65,7 @@ jobs:
6565
run: touch site/.nojekyll
6666

6767
- name: Upload artifact
68-
uses: actions/upload-artifact@v5
68+
uses: actions/upload-artifact@v6
6969
with:
7070
name: docs-site
7171
path: site/
@@ -85,7 +85,7 @@ jobs:
8585
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
8686
8787
- name: Download artifact
88-
uses: actions/download-artifact@v6
88+
uses: actions/download-artifact@v7
8989
with:
9090
name: docs-site
9191
path: site
@@ -94,7 +94,7 @@ jobs:
9494
run: touch site/.nojekyll
9595

9696
- name: Deploy to Github pages
97-
uses: JamesIves/[email protected].4
97+
uses: JamesIves/[email protected].6
9898
with:
9999
branch: gh-pages
100100
folder: site

.github/workflows/integration_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ jobs:
4444
- uses: actions/[email protected]
4545

4646
- name: Install uv
47-
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a
47+
uses: astral-sh/setup-uv@681c641aba71e4a1c380be3ab5e12ad51f415867
4848
with:
4949
# Install a specific version of uv.
5050
version: "0.5.21"

.github/workflows/publish.yml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ jobs:
1919
- uses: actions/[email protected]
2020

2121
- name: Install uv
22-
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a
22+
uses: astral-sh/setup-uv@681c641aba71e4a1c380be3ab5e12ad51f415867
2323
with:
2424
# Install a specific version of uv.
2525
version: "0.5.21"

.github/workflows/unit_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ jobs:
4444
- uses: actions/[email protected]
4545

4646
- name: Install uv
47-
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a
47+
uses: astral-sh/setup-uv@681c641aba71e4a1c380be3ab5e12ad51f415867
4848
with:
4949
# Install a specific version of uv.
5050
version: "0.5.21"

.gitignore

Lines changed: 1 addition & 10 deletions
@@ -31,6 +31,7 @@ wheels/
3131

3232
# Data files
3333
examples/**/*data/
34+
examples/**/*results/
3435
examples/**/*.csv
3536
examples/**/*.npy
3637

@@ -43,16 +44,6 @@ outputs/
4344
# mkdocs site
4445
site/
4546

46-
# Training examples
47-
examples/training/single_table/data/**
48-
examples/training/single_table/results/**
49-
examples/training/multi_table/data/**
50-
examples/training/multi_table/results/**
51-
examples/synthesizing/single_table/data/**
52-
examples/synthesizing/single_table/results/**
53-
examples/synthesizing/multi_table/data/**
54-
examples/synthesizing/multi_table/results/**
55-
5647
# Test artifacts
5748
tests/integration/attacks/tartan_federer/assets/tabddpm_models/**/challenge_label_predictions.csv
5849
tests/integration/attacks/tartan_federer/assets/tartan_federer_attack_results

examples/gan/README.md

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
1+
# CTGAN Single-Table Example
2+
3+
This example will go over training a single-table [CTGAN](https://arxiv.org/pdf/1907.00503)
4+
model using the [CTGAN](https://github.com/sdv-dev/CTGAN/) library and then synthesizing
5+
some data afterwards.
6+
7+
8+
## Downloading data
9+
10+
First, we need the data. Download it from this
11+
[Google Drive link](https://drive.google.com/file/d/1J5qDuMHHg4dm9c3ISmb41tcTHSu1SVUC/view?usp=drive_link),
12+
extract the files, and place them in a `/data` folder within this folder
13+
(`examples/gan`).
14+
15+
> [!NOTE]
16+
> If you wish to change the data folder, you can do so by editing the `base_data_dir` attribute
17+
> of the [`config.yaml`](config.yaml) file.
18+
19+
Here is a description of the files that have been extracted:
20+
- `trans.csv`: The training data. It consists of information about bank transactions and it
21+
contains 20,000 data points.
22+
- `trans_domain.json`: Metadata about the columns in `trans.csv`, such as data types and sizes.
23+
- `dataset_meta.json`: Metadata about the relationship between the tables. Since this is a
24+
single-table example, it will only contain information about the `trans` table.
25+
- `meta_info.json`: Metadata about the dataset, namely which columns are numerical and
26+
which ones are categorical, the target column and the task type (e.g. `regression`).
27+
28+
29+
## Kicking off training
30+
31+
To kick off training, simply run the command below from the project's root folder:
32+
33+
```bash
34+
python -m examples.gan.train
35+
```
36+
37+
38+
## Training results
39+
40+
The result files will be saved inside a `/results` folder within this folder
41+
(`examples/gan`).
42+
43+
> [!NOTE]
44+
> If you wish to change the save folder, you can do so by editing the `results_dir` attribute
45+
> of the [`config.yaml`](config.yaml) file.
46+
47+
In the `/results` folder, there will be a file called `trained_ctgan_model.pkl`,
48+
which is a pickle file containing the trained model. You can load it using CTGAN's
49+
`load` function:
50+
51+
```python
52+
from pathlib import Path
53+
from ctgan import CTGAN
54+
55+
results_file = Path("examples/gan/results/trained_ctgan_model.pkl")
56+
57+
ctgan = CTGAN.load(results_file)
58+
```
59+
60+
## Synthesizing data
61+
62+
To synthesize some data with the trained model, run:
63+
64+
```bash
65+
python -m examples.gan.synthesize
66+
```
67+
68+
If there is already a trained model in the `/results` folder, it will use that model.
69+
Otherwise it will train one from scratch. At the end of the script, it will save the
70+
synthesized data to `/results/trans_synthetic.csv`.
71+
72+
73+
## Evaluating the quality of the synthetic data
74+
75+
### Alpha Precision
76+
77+
To run a round of evaluation with [Alpha Precision](https://arxiv.org/abs/2301.07573)
78+
metrics on a set of synthetic data, run the `midst_alpha_precision_eval` script:
79+
80+
```bash
81+
python -m midst_toolkit.evaluation.quality.scripts.midst_alpha_precision_eval \
82+
--synthetic_data_path examples/gan/results/trans_synthetic.csv \
83+
--real_data examples/gan/data/trans.csv \
84+
--meta_info_path examples/gan/data/meta_info.json \
85+
--save_directory examples/gan/results/
86+
```
87+
88+
It will save the evaluation results under the `/results/model.txt` file.
89+
90+
### Additional Metrics
91+
92+
The calculation of additional metrics is set up in the `evaluate.py` file. They are the
93+
Kolmogorov-Smirnov (KS) test, Total Variation Distance (TVD), Correlation Matrix Difference
94+
and Mutual Information Difference.
95+
96+
To compute those metrics, you can run the command below. The name of the table should be
97+
defined in the `dataset_meta.json` file. The real data should be under
98+
`/data/{table_name}.csv` and the synthetic data under
99+
`/results/{table_name}_synthetic.csv`.
100+
101+
```bash
102+
python -m examples.gan.evaluate
103+
```
104+
105+
The results will be saved in the `/results/evaluation.json` file.
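
Reviewer note: for readers unfamiliar with the first two metrics named in the README, the sketch below shows what a two-sample Kolmogorov-Smirnov statistic and a Total Variation Distance measure on a single column. It is a standard-library illustration only, not the `midst_toolkit` implementation, which may preprocess the data, aggregate over columns, or report a complement of these distances. The function names are hypothetical.

```python
from bisect import bisect_right
from collections import Counter


def ks_statistic(real: list[float], synthetic: list[float]) -> float:
    """Largest gap between the two empirical CDFs (two-sample KS statistic)."""
    real_sorted = sorted(real)
    synthetic_sorted = sorted(synthetic)
    max_gap = 0.0
    # The empirical CDFs only change at observed values, so checking those is enough.
    for value in real_sorted + synthetic_sorted:
        real_cdf = bisect_right(real_sorted, value) / len(real_sorted)
        synthetic_cdf = bisect_right(synthetic_sorted, value) / len(synthetic_sorted)
        max_gap = max(max_gap, abs(real_cdf - synthetic_cdf))
    return max_gap


def total_variation_distance(real: list[str], synthetic: list[str]) -> float:
    """Half the L1 distance between the two category frequency distributions."""
    real_freq = Counter(real)
    synthetic_freq = Counter(synthetic)
    categories = set(real_freq) | set(synthetic_freq)
    return 0.5 * sum(
        abs(real_freq[c] / len(real) - synthetic_freq[c] / len(synthetic))
        for c in categories
    )
```

Both values lie in the range 0 to 1, where 0 means the real and synthetic columns have identical empirical distributions.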

examples/gan/config.yaml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
1+
# Training example configuration
2+
# Base data directory (can be overridden from command line)
3+
base_data_dir: examples/gan/data
4+
results_dir: examples/gan/results
5+
6+
training:
7+
epochs: 300
8+
verbose: True
9+
10+
synthesizing:
11+
sample_size: 20000
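
Reviewer note: the comment in this config says values can be overridden from the command line. With Hydra that is done by appending `key=value` pairs to the command, for example `python -m examples.gan.train training.epochs=10`. The sketch below illustrates how such a dotted override maps onto the nested config. It is a standard-library illustration, not Hydra's or OmegaConf's implementation, and `apply_override` is a hypothetical helper.

```python
import copy
from typing import Any

# Values copied from examples/gan/config.yaml in this commit.
CONFIG: dict[str, Any] = {
    "base_data_dir": "examples/gan/data",
    "results_dir": "examples/gan/results",
    "training": {"epochs": 300, "verbose": True},
    "synthesizing": {"sample_size": 20000},
}


def apply_override(config: dict[str, Any], override: str) -> dict[str, Any]:
    """Return a copy of `config` with one `dotted.key=value` override applied."""
    dotted_key, raw_value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    updated = copy.deepcopy(config)
    node = updated
    for key in parents:
        node = node[key]  # raises KeyError if the section does not exist
    # Keep the type of the existing value (bool is checked first because
    # bool is a subclass of int in Python).
    current = node[leaf]
    if isinstance(current, bool):
        node[leaf] = raw_value.lower() == "true"
    else:
        node[leaf] = type(current)(raw_value)
    return updated


shorter_run = apply_override(CONFIG, "training.epochs=10")
```

The original `CONFIG` is left untouched, which mirrors the idea that overrides apply to one run only and do not edit `config.yaml`.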

examples/gan/evaluate.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
1+
import json
2+
from logging import INFO
3+
from pathlib import Path
4+
5+
import hydra
6+
import pandas as pd
7+
from omegaconf import DictConfig
8+
9+
from examples.gan.utils import get_table_name
10+
from midst_toolkit.common.logger import log
11+
from midst_toolkit.evaluation.quality.correlation_matrix_difference import CorrelationMatrixDifference
12+
from midst_toolkit.evaluation.quality.kolmogorov_smirnov_total_variation import KolmogorovSmirnovAndTotalVariation
13+
from midst_toolkit.evaluation.quality.mutual_information_difference import MutualInformationDifference
14+
15+
16+
@hydra.main(config_path=".", config_name="config", version_base=None)
17+
def main(config: DictConfig) -> None:
18+
"""
19+
Run the evaluation pipeline for the KS/TVD, Correlation Matrix Difference and Mutual Information Difference metrics.
20+
21+
It will load the config and then data from the `config.base_data_dir` folder for the table
22+
name (from the `dataset_meta.json` file) and the real data under `{table_name}.csv`, and
23+
the synthetic data from the `config.results_dir` folder under `{table_name}_synthetic.csv`,
24+
and then compute the KS/TVD, Correlation Matrix Difference and Mutual Information Difference metrics.
25+
26+
It will also need the meta_info.json file for the information about categorical and numerical
27+
columns.
28+
29+
The results will be saved in the `config.results_dir` folder under `evaluation.json`.
30+
31+
Args:
32+
config: Configuration as an OmegaConf DictConfig object.
33+
"""
34+
log(INFO, "Loading data...")
35+
36+
table_name = get_table_name(config.base_data_dir)
37+
38+
real_data = pd.read_csv(Path(config.base_data_dir) / f"{table_name}.csv")
39+
synthetic_data = pd.read_csv(Path(config.results_dir) / f"{table_name}_synthetic.csv")
40+
41+
with open(Path(config.base_data_dir) / "meta_info.json", "r") as f:
42+
meta_info = json.load(f)
43+
44+
numerical_columns = [real_data.columns[i] for i in meta_info["num_col_idx"]]
45+
categorical_columns = [real_data.columns[i] for i in meta_info["cat_col_idx"]]
46+
47+
results = {}
48+
49+
# KS and TVD
50+
ks_tvd_metric = KolmogorovSmirnovAndTotalVariation(categorical_columns, numerical_columns, do_preprocess=True)
51+
ks_tvd_score = ks_tvd_metric.compute(real_data, synthetic_data)
52+
53+
log(INFO, f"Kolmogorov-Smirnov and Total Variation Distance score: {ks_tvd_score}")
54+
results["ks_tvd"] = ks_tvd_score
55+
56+
# Correlation Matrix Difference
57+
cmd_metric = CorrelationMatrixDifference(categorical_columns, numerical_columns, do_preprocess=True)
58+
cmd_result = cmd_metric.compute(real_data, synthetic_data)
59+
60+
log(INFO, f"Correlation Matrix Difference score: {cmd_result}")
61+
results["correlation_matrix_difference"] = cmd_result
62+
63+
# Mutual Information Difference
64+
mid_metric = MutualInformationDifference(categorical_columns, numerical_columns, do_preprocess=True)
65+
mid_result = mid_metric.compute(real_data, synthetic_data)
66+
mid_result["score"] = mid_result["mutual_inf_diff"] / mid_result["mi_mat_dims"]
67+
68+
log(INFO, f"Mutual Information Difference score: {mid_result}")
69+
results["mutual_information_difference"] = mid_result
70+
71+
log(INFO, "Saving results...")
72+
with open(Path(config.results_dir) / "evaluation.json", "w") as f:
73+
json.dump(results, f, indent=4)
74+
75+
log(INFO, "Done!")
76+
77+
78+
if __name__ == "__main__":
79+
main()
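
Reviewer note: two small steps in `evaluate.py` are easy to miss when reading the diff. Column names are resolved from the positional indices stored in `meta_info.json` (`num_col_idx`, `cat_col_idx`), and the Mutual Information Difference result gets an extra `score` entry equal to `mutual_inf_diff / mi_mat_dims`. The sketch below isolates both steps. The helper names and the column names in the usage lines are made up for illustration and are not taken from the dataset.

```python
from typing import Any


def split_columns(columns: list[str], meta_info: dict[str, Any]) -> tuple[list[str], list[str]]:
    """Map the index lists in meta_info onto column names, as evaluate.py does."""
    numerical = [columns[i] for i in meta_info["num_col_idx"]]
    categorical = [columns[i] for i in meta_info["cat_col_idx"]]
    return numerical, categorical


def normalized_score(result: dict[str, float]) -> float:
    """Divide the summed difference by the matrix size, as evaluate.py does."""
    return result["mutual_inf_diff"] / result["mi_mat_dims"]


# Hypothetical column layout: index 0 is an identifier and is in neither list.
numerical, categorical = split_columns(
    ["trans_id", "amount", "type"],
    {"num_col_idx": [1], "cat_col_idx": [2]},
)
```

Because the lookup is positional, the column order of `{table_name}.csv` has to match the order assumed when `meta_info.json` was written.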

examples/gan/synthesize.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
1+
from logging import INFO
2+
from pathlib import Path
3+
4+
import hydra
5+
from omegaconf import DictConfig
6+
from sdv.single_table import CTGANSynthesizer # type: ignore[import-untyped]
7+
8+
from examples.gan.train import main as train_main
9+
from examples.gan.utils import get_table_name
10+
from midst_toolkit.common.logger import log
11+
12+
13+
@hydra.main(config_path=".", config_name="config", version_base=None)
14+
def main(config: DictConfig) -> None:
15+
"""
16+
Run the synthesizing pipeline for a single-table CTGAN model.
17+
18+
It will load the config and then data from the `config.base_data_dir` folder,
19+
load the trained model (or train one if it doesn't exist) and save the results
20+
in the `config.results_dir` folder.
21+
22+
Args:
23+
config: Configuration as an OmegaConf DictConfig object.
24+
"""
25+
results_file = Path(config.results_dir) / "trained_ctgan_model.pkl"
26+
27+
if not results_file.exists():
28+
log(INFO, f"Trained model not found at {results_file}. Training a new model from scratch.")
29+
train_main(config)
30+
31+
log(INFO, f"Loading model from {results_file}...")
32+
ctgan = CTGANSynthesizer.load(results_file)
33+
34+
log(INFO, f"Synthesizing data of size {config.synthesizing.sample_size}...")
35+
synthetic_data = ctgan.sample(num_rows=config.synthesizing.sample_size)
36+
37+
table_name = get_table_name(config.base_data_dir)
38+
synthetic_data_file = Path(config.results_dir) / f"{table_name}_synthetic.csv"
39+
40+
log(INFO, f"Saving synthetic data to {synthetic_data_file}...")
41+
synthetic_data.to_csv(synthetic_data_file, index=False)
42+
43+
log(INFO, "Done!")
44+
45+
46+
if __name__ == "__main__":
47+
main()
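
Reviewer note: `synthesize.py` follows a load-or-train pattern: if the model file is missing it trains first, then always loads from disk. The sketch below shows that control flow in isolation, with a plain pickled dictionary standing in for the model. `load_or_train` and `fake_train` are hypothetical stand-ins, and the real script uses `CTGANSynthesizer.load` rather than `pickle` directly.

```python
import pickle
from pathlib import Path
from typing import Callable


def load_or_train(model_file: Path, train: Callable[[Path], None]) -> object:
    """Train only when no saved model exists, then load whatever is on disk."""
    if not model_file.exists():
        train(model_file)
    with model_file.open("rb") as f:
        return pickle.load(f)


def fake_train(model_file: Path) -> None:
    """Stand-in for the training script: writes a placeholder 'model'."""
    model_file.parent.mkdir(parents=True, exist_ok=True)
    with model_file.open("wb") as f:
        pickle.dump({"epochs": 300}, f)
```

Loading from disk in both branches keeps a single code path for sampling, at the cost of one extra read right after training.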
