|
139 | 139 | ], |
140 | 140 | "source": [ |
141 | 141 | "# 1. Initiate metadata\n", |
142 | | - "metadata = MissingDataHandler()\n", |
| 142 | + "md_handler = MissingDataHandler()\n", |
143 | 143 | "\n", |
144 | 144 | "# 1.1 Get data types\n", |
145 | | - "column_dtypes = metadata.get_column_dtypes(df)\n", |
146 | | - "print(\"Column Data Types:\", column_dtypes)" |
| 145 | + "metadata= md_handler.get_column_dtypes(df)\n", |
| 146 | + "print(\"Column Data Types:\", metadata)" |
147 | 147 | ] |
148 | 148 | }, |
149 | 149 | { |
|
185 | 185 | ], |
186 | 186 | "source": [ |
187 | 187 | "# 2.1 Detect type of missingness\n", |
188 | | - "missingness_dict = metadata.detect_missingness(df)\n", |
| 188 | + "missingness_dict = md_handler.detect_missingness(df)\n", |
189 | 189 | "print(\"Detected missingness type:\", missingness_dict)" |
190 | 190 | ] |
191 | 191 | }, |
192 | 192 | { |
193 | 193 | "cell_type": "code", |
194 | | - "execution_count": 8, |
| 194 | + "execution_count": 7, |
195 | 195 | "metadata": {}, |
196 | 196 | "outputs": [ |
197 | 197 | { |
|
210 | 210 | ], |
211 | 211 | "source": [ |
212 | 212 | "# 2.2 Impute missing values\n", |
213 | | - "df_imputed = metadata.apply_imputation(df, missingness_dict)\n", |
| 213 | + "df_imputed = md_handler.apply_imputation(df, missingness_dict)\n", |
214 | 214 | "\n", |
215 | 215 | "print(df_imputed.isnull().sum())" |
216 | 216 | ] |
217 | 217 | }, |
218 | 218 | { |
219 | 219 | "cell_type": "code", |
220 | | - "execution_count": 9, |
| 220 | + "execution_count": 8, |
221 | 221 | "metadata": {}, |
222 | 222 | "outputs": [ |
223 | 223 | { |
|
321 | 321 | ], |
322 | 322 | "source": [ |
323 | 323 | "# 3. Instantiate the DataProcessor with column_dtypes\n", |
324 | | - "processor = DataProcessor(column_dtypes)\n", |
| 324 | + "processor = DataProcessor(metadata)\n", |
325 | 325 | "\n", |
326 | 326 | "# 3.1 Preprocess the data: transforms raw data into a numerical format\n", |
327 | 327 | "processed_data = processor.preprocess(df)\n", |
|
331 | 331 | }, |
332 | 332 | { |
333 | 333 | "cell_type": "code", |
334 | | - "execution_count": null, |
| 334 | + "execution_count": 9, |
335 | 335 | "metadata": {}, |
336 | 336 | "outputs": [ |
337 | 337 | { |
|
348 | 348 | "cart.fit(processed_data)" |
349 | 349 | ] |
350 | 350 | }, |
| 351 | + { |
| 352 | + "cell_type": "code", |
| 353 | + "execution_count": 11, |
| 354 | + "metadata": {}, |
| 355 | + "outputs": [ |
| 356 | + { |
| 357 | + "name": "stdout", |
| 358 | + "output_type": "stream", |
| 359 | + "text": [ |
| 360 | + "Synthetic Processed Data:\n" |
| 361 | + ] |
| 362 | + }, |
| 363 | + { |
| 364 | + "data": { |
| 365 | + "text/html": [ |
| 366 | + "<div>\n", |
| 367 | + "<style scoped>\n", |
| 368 | + " .dataframe tbody tr th:only-of-type {\n", |
| 369 | + " vertical-align: middle;\n", |
| 370 | + " }\n", |
| 371 | + "\n", |
| 372 | + " .dataframe tbody tr th {\n", |
| 373 | + " vertical-align: top;\n", |
| 374 | + " }\n", |
| 375 | + "\n", |
| 376 | + " .dataframe thead th {\n", |
| 377 | + " text-align: right;\n", |
| 378 | + " }\n", |
| 379 | + "</style>\n", |
| 380 | + "<table border=\"1\" class=\"dataframe\">\n", |
| 381 | + " <thead>\n", |
| 382 | + " <tr style=\"text-align: right;\">\n", |
| 383 | + " <th></th>\n", |
| 384 | + " <th>sex</th>\n", |
| 385 | + " <th>age</th>\n", |
| 386 | + " <th>marital</th>\n", |
| 387 | + " <th>ls</th>\n", |
| 388 | + " <th>smoke</th>\n", |
| 389 | + " </tr>\n", |
| 390 | + " </thead>\n", |
| 391 | + " <tbody>\n", |
| 392 | + " <tr>\n", |
| 393 | + " <th>0</th>\n", |
| 394 | + " <td>0</td>\n", |
| 395 | + " <td>-1.123252</td>\n", |
| 396 | + " <td>4</td>\n", |
| 397 | + " <td>2</td>\n", |
| 398 | + " <td>0</td>\n", |
| 399 | + " </tr>\n", |
| 400 | + " <tr>\n", |
| 401 | + " <th>1</th>\n", |
| 402 | + " <td>1</td>\n", |
| 403 | + " <td>0.704909</td>\n", |
| 404 | + " <td>3</td>\n", |
| 405 | + " <td>4</td>\n", |
| 406 | + " <td>1</td>\n", |
| 407 | + " </tr>\n", |
| 408 | + " <tr>\n", |
| 409 | + " <th>2</th>\n", |
| 410 | + " <td>0</td>\n", |
| 411 | + " <td>1.583713</td>\n", |
| 412 | + " <td>5</td>\n", |
| 413 | + " <td>3</td>\n", |
| 414 | + " <td>0</td>\n", |
| 415 | + " </tr>\n", |
| 416 | + " <tr>\n", |
| 417 | + " <th>3</th>\n", |
| 418 | + " <td>0</td>\n", |
| 419 | + " <td>-0.127991</td>\n", |
| 420 | + " <td>3</td>\n", |
| 421 | + " <td>4</td>\n", |
| 422 | + " <td>1</td>\n", |
| 423 | + " </tr>\n", |
| 424 | + " <tr>\n", |
| 425 | + " <th>4</th>\n", |
| 426 | + " <td>0</td>\n", |
| 427 | + " <td>0.868010</td>\n", |
| 428 | + " <td>3</td>\n", |
| 429 | + " <td>4</td>\n", |
| 430 | + " <td>0</td>\n", |
| 431 | + " </tr>\n", |
| 432 | + " </tbody>\n", |
| 433 | + "</table>\n", |
| 434 | + "</div>" |
| 435 | + ], |
| 436 | + "text/plain": [ |
| 437 | + " sex age marital ls smoke\n", |
| 438 | + "0 0 -1.123252 4 2 0\n", |
| 439 | + "1 1 0.704909 3 4 1\n", |
| 440 | + "2 0 1.583713 5 3 0\n", |
| 441 | + "3 0 -0.127991 3 4 1\n", |
| 442 | + "4 0 0.868010 3 4 0" |
| 443 | + ] |
| 444 | + }, |
| 445 | + "metadata": {}, |
| 446 | + "output_type": "display_data" |
| 447 | + } |
| 448 | + ], |
| 449 | + "source": [ |
| 450 | + "# 4.1 Preview generated synthetic data\n", |
| 451 | + "synthetic_processed = cart.sample(100)\n", |
| 452 | + "print(\"Synthetic Processed Data:\")\n", |
| 453 | + "display(synthetic_processed.head())" |
| 454 | + ] |
| 455 | + }, |
351 | 456 | { |
352 | 457 | "cell_type": "code", |
353 | 458 | "execution_count": null, |
354 | 459 | "metadata": {}, |
| 460 | + "outputs": [ |
| 461 | + { |
| 462 | + "ename": "KeyError", |
| 463 | + "evalue": "\"None of [Index(['income'], dtype='object')] are in the [columns]\"", |
| 464 | + "output_type": "error", |
| 465 | + "traceback": [ |
| 466 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 467 | + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", |
| 468 | + "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# 4.2 Postprocess the synthetic data back to the original format\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m synthetic_data \u001b[38;5;241m=\u001b[39m \u001b[43mprocessor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpostprocess\u001b[49m\u001b[43m(\u001b[49m\u001b[43msynthetic_processed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSynthetic Data in Original Format:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m display(synthetic_data\u001b[38;5;241m.\u001b[39mhead())\n", |
| 469 | + "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/synthpop/processor/data_processor.py:90\u001b[0m, in \u001b[0;36mDataProcessor.postprocess\u001b[0;34m(self, synthetic_data)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m dtype \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnumerical\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m col \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscalers:\n\u001b[1;32m 89\u001b[0m scaler \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscalers[col]\n\u001b[0;32m---> 90\u001b[0m synthetic_data[col] \u001b[38;5;241m=\u001b[39m scaler\u001b[38;5;241m.\u001b[39minverse_transform(\u001b[43msynthetic_data\u001b[49m\u001b[43m[\u001b[49m\u001b[43m[\u001b[49m\u001b[43mcol\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m)\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m dtype \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mboolean\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 93\u001b[0m synthetic_data[col] \u001b[38;5;241m=\u001b[39m synthetic_data[col]\u001b[38;5;241m.\u001b[39mround()\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;28mbool\u001b[39m)\n", |
| 470 | + "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/pandas/core/frame.py:4108\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 4106\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_iterator(key):\n\u001b[1;32m 4107\u001b[0m key \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(key)\n\u001b[0;32m-> 4108\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_indexer_strict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcolumns\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 4110\u001b[0m \u001b[38;5;66;03m# take() does not accept boolean indexers\u001b[39;00m\n\u001b[1;32m 4111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(indexer, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdtype\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mbool\u001b[39m:\n", |
| 471 | + "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/pandas/core/indexes/base.py:6200\u001b[0m, in \u001b[0;36mIndex._get_indexer_strict\u001b[0;34m(self, key, axis_name)\u001b[0m\n\u001b[1;32m 6197\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 6198\u001b[0m keyarr, indexer, new_indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reindex_non_unique(keyarr)\n\u001b[0;32m-> 6200\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_raise_if_missing\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkeyarr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindexer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_name\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6202\u001b[0m keyarr \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtake(indexer)\n\u001b[1;32m 6203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, Index):\n\u001b[1;32m 6204\u001b[0m \u001b[38;5;66;03m# GH 42790 - Preserve name from an Index\u001b[39;00m\n", |
| 472 | + "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/pandas/core/indexes/base.py:6249\u001b[0m, in \u001b[0;36mIndex._raise_if_missing\u001b[0;34m(self, key, indexer, axis_name)\u001b[0m\n\u001b[1;32m 6247\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m nmissing:\n\u001b[1;32m 6248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m nmissing \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(indexer):\n\u001b[0;32m-> 6249\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNone of [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m] are in the [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00maxis_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m]\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 6251\u001b[0m not_found \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(ensure_index(key)[missing_mask\u001b[38;5;241m.\u001b[39mnonzero()[\u001b[38;5;241m0\u001b[39m]]\u001b[38;5;241m.\u001b[39munique())\n\u001b[1;32m 6252\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnot_found\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not in index\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", |
| 473 | + "\u001b[0;31mKeyError\u001b[0m: \"None of [Index(['income'], dtype='object')] are in the [columns]\"" |
| 474 | + ] |
| 475 | + } |
| 476 | + ], |
| 477 | + "source": [ |
| 478 | + "# 4.2 Postprocess the synthetic data back to the original format and give preview of generated synthetic data\n", |
| 479 | + "synthetic_data = processor.postprocess(synthetic_processed)\n", |
| 480 | + "print(\"Synthetic Data in Original Format:\")\n", |
| 481 | + "display(synthetic_data.head())" |
| 482 | + ] |
| 483 | + }, |
| 484 | + { |
| 485 | + "cell_type": "code", |
| 486 | + "execution_count": 10, |
| 487 | + "metadata": {}, |
355 | 488 | "outputs": [], |
356 | 489 | "source": [ |
357 | 490 | "from synthpop.metrics import (\n", |
|
0 commit comments