Skip to content

Commit 33254e1

Browse files
committed
Fatemeh's CR
1 parent 0bafdf5 commit 33254e1

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

examples/gan/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from omegaconf import DictConfig
88
from sdv.single_table import CTGANSynthesizer # type: ignore[import-untyped]
99

10-
from examples.gan.utils import get_metadata, get_table_name
10+
from examples.gan.utils import get_single_table_svd_metadata, get_table_name
1111
from midst_toolkit.common.logger import log
1212

1313

@@ -31,7 +31,7 @@ def main(config: DictConfig) -> None:
3131

3232
real_data = pd.read_csv(Path(config.base_data_dir) / f"{table_name}.csv")
3333

34-
metadata, real_data_without_ids = get_metadata(real_data, domain_info)
34+
metadata, real_data_without_ids = get_single_table_svd_metadata(real_data, domain_info)
3535

3636
log(INFO, "Fitting CTGAN...")
3737

examples/gan/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def get_table_name(base_data_dir: Path) -> str:
2727
return list(dataset_meta["tables"].keys())[0]
2828

2929

30-
def get_metadata(
30+
def get_single_table_svd_metadata(
3131
data: pd.DataFrame,
3232
domain_dictionary: dict[str, Any] | None = None,
3333
) -> tuple[SingleTableMetadata, pd.DataFrame]:
3434
"""
35-
Get the metadata for a single-table dataset.
35+
Get the metadata for a single-table dataset for SDV models.
3636
3737
Args:
3838
data: The dataframe containing the data.
@@ -43,7 +43,7 @@ def get_metadata(
4343
"""
4444
metadata = SingleTableMetadata()
4545
data_without_ids = data.drop(columns=[column_name for column_name in data.columns if "_id" in column_name])
46-
metadata.detect_from_dataframe(data_without_ids)
46+
metadata.detect_from_dataframe(data_without_ids) # Starts up the metadata info from the dataframe's columns.
4747

4848
if domain_dictionary is not None:
4949
for column_name in data_without_ids.columns:

0 commit comments

Comments
 (0)