Ft/ensemble experiments test script #96
Conversation
… uses target synthetic data.
…rent experimental setups
run_shadow_model_training: true # Set this to false if shadow models are already trained and saved
run_metaclassifier_training: true

target_model: # This is only used for testing the attack on a real target model.
We are attacking the target model tabddpm_21 with the trained metaclassifier.
# The column name in the data to be used for stratified splitting.
column_to_stratify: "trans_type" # Attention: This value is not documented in the original codebase.
folder_ranges: # Specify folder ranges for any of the mentioned splits.
  train: [[1, 20]] # Folders to be used for train data collection in the experiments
Model IDs used for training the metaclassifier (attack model).
📝 Walkthrough
This pull request refactors the ensemble attack module's data flow and configuration architecture. The changes migrate from pickle-based persistence and dictionary-structured data to CSV/DataFrame-based persistence with explicit configuration paths. Key modifications include: (1) introducing a centralized …
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–75 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 6
🧹 Nitpick comments (19)
examples/ensemble_attack/data_configs/trans.json (1)
23-23: Review the regularization impact of setting dropout to 0.0.
Dropout was reduced from 0.1 to 0.0, completely disabling this regularization mechanism. Combined with the massively expanded network architecture (lines 15-21), this could lead to overfitting on small datasets or reduced generalization.
Verify whether:
- This change was intentional as part of the experimental setup
- The dataset size or early stopping strategies compensate for the lack of regularization
.gitignore (1)
56-58: Consider scoping the training log patterns more narrowly.
The wildcard patterns *.err and *.out will ignore all files with these extensions across the entire repository. If these are SLURM job logs specific to the ensemble attack examples, consider scoping them more narrowly.
Apply this diff to scope the patterns to the examples directory:
 # Training Logs
-*.err
-*.out
+examples/**/*.err
+examples/**/*.out
examples/ensemble_attack/configs/original_attack_config.yaml (1)
76-80: Document the dramatic increase in computational requirements.
The hyperparameters have been increased by orders of magnitude (e.g., diffusion iterations: 2→200000, synthetic samples: 200→20000, Optuna trials: 10→100). While these production values likely improve attack quality, they will dramatically increase runtime and resource consumption.
Consider:
- Adding comments in the config file indicating estimated runtime/resource requirements
- Documenting these changes in the PR description or a README
- Providing a separate "quick test" configuration with smaller values for development/testing
examples/ensemble_attack/run_test.sh (1)
17-17: Add error handling for virtual environment activation.
The script assumes .venv exists at the repository root without checking. If the virtual environment doesn't exist or activation fails, the script will continue and likely fail with unclear errors.
Apply this diff to add error handling:
 # This script sets up the environment and runs the ensemble attack example.
-source .venv/bin/activate
+if [ ! -f .venv/bin/activate ]; then
+    echo "Error: Virtual environment not found at .venv/"
+    echo "Please create it with: python -m venv .venv"
+    exit 1
+fi
+
+source .venv/bin/activate || { echo "Failed to activate virtual environment"; exit 1; }
examples/ensemble_attack/run_train.sh (2)
17-17: Add error handling for virtual environment activation.
The script assumes .venv exists at the repository root without checking. Consider adding the same error handling recommended for run_test.sh.
Apply this diff:
 # This script sets up the environment and runs the ensemble attack example.
-source .venv/bin/activate
+if [ ! -f .venv/bin/activate ]; then
+    echo "Error: Virtual environment not found at .venv/"
+    echo "Please create it with: python -m venv .venv"
+    exit 1
+fi
+
+source .venv/bin/activate || { echo "Failed to activate virtual environment"; exit 1; }
1-26: Consider consolidating the two SLURM scripts.
run_train.sh and run_test.sh are nearly identical, differing only in job name, time limit, and the Python module executed. This duplication could be reduced by using a parameterized script or a wrapper.
Example approach:
#!/bin/bash
# run_ensemble.sh <train|test>
MODE=${1:-train}
if [ "$MODE" = "train" ]; then
    TIME=12:00:00
    MODULE=examples.ensemble_attack.run_attack
elif [ "$MODE" = "test" ]; then
    TIME=5:00:00
    MODULE=examples.ensemble_attack.test_attack_model
else
    echo "Usage: $0 <train|test>"
    exit 1
fi
#SBATCH --time=$TIME
#SBATCH --job-name=ensemble_attack_$MODE
# ... rest of config
python -m $MODULE
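One practical note on this sketch: SLURM parses #SBATCH directives before the shell runs, so they cannot reference shell variables such as $TIME or $MODE. A wrapper along these lines would typically keep the static directives in the script and pass the variable parts on the command line instead, for example sbatch --time=05:00:00 --job-name=ensemble_attack_test run_ensemble.sh test (command shown for illustration only).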
examples/ensemble_attack/run_shadow_model_training.py (1)
26-27: Return a Path object instead of a raw config string for consistency
The new flow of saving only train_result.synthetic_data and returning its path looks good and matches the metaclassifier's needs. However, config.shadow_training.target_synthetic_data_path is almost certainly a string from Hydra, while the function is annotated as returning Path and other call sites treat it as such.
To keep types consistent across the pipeline (including run_attack.main and run_metaclassifier_training), consider converting to Path here:
-    # Save the target model's synthetic data
-    target_model_synthetic_path = config.shadow_training.target_synthetic_data_path
-    target_synthetic_data.to_csv(target_model_synthetic_path, index=False)
-
-    return target_model_synthetic_path
+    # Save the target model's synthetic data
+    target_model_synthetic_path = Path(config.shadow_training.target_synthetic_data_path)
+    target_synthetic_data.to_csv(target_model_synthetic_path, index=False)
+
+    return target_model_synthetic_path
Also applies to: 70-79
examples/ensemble_attack/run_attack.py (1)
69-80: Refresh assertion messages to reflect new names and ensure path type consistency
The control flow around target_model_synthetic_path is correct, but there are two small clarity issues:
- The assertion message for the shadows list still references the old name: line 77 says "The attack_data_paths list must contain exactly three elements.", but the variable is shadow_data_paths.
- The assertion error message for the target path still references target_data_path rather than the new synthetic path naming.
Consider updating the messages:
-    assert len(shadow_data_paths) == 3, "The attack_data_paths list must contain exactly three elements."
-    assert target_model_synthetic_path is not None, (
-        "The target_data_path must be provided for metaclassifier training."
-    )
+    assert len(shadow_data_paths) == 3, "The shadow_data_paths list must contain exactly three elements."
+    assert target_model_synthetic_path is not None, (
+        "The target_model_synthetic_path must be provided for metaclassifier training."
+    )
Once run_target_model_training is updated to return a Path (as suggested in its file), target_model_synthetic_path will be a Path in both branches, which keeps the type consistent going into run_metaclassifier_training.
Also applies to: 83-84
tests/unit/attacks/ensemble/test_rmia.py (1)
96-107: Fixture update correctly matches the new calculate_rmia_signals API
Including target_synthetic_data directly in rmia_signal_data keeps the fixture aligned with the refactored calculate_rmia_signals(**rmia_signal_data) signature and isolates the tests from any previous TrainingResult dict structure.
If you want to simplify slightly, this line:
target_synthetic_data = MockTrainingResult(syn_data_5.copy()).synthetic_data
could just be:
target_synthetic_data = syn_data_5.copy()
since the namedtuple adds no extra behavior here.
examples/ensemble_attack/configs/experiment_config.yaml (1)
1-107: Config structure matches the example pipeline; clarify a couple of comments and environment assumptions
This new experiment_config.yaml lines up well with the example code:
- pipeline.* flags drive which stages run in run_attack.main.
- data_paths.* and data_processing_config.* cover everything run_data_processing and process_split_data need, including the new population_splits/challenge_splits.
- shadow_training.* defines all paths and knobs used by run_shadow_model_training and run_target_model_training, including target_synthetic_data_path.
- metaclassifier.* (especially data_types_file_path, model_type, and meta_classifier_model_name) matches how run_metaclassifier_training and the BlendingPlusPlus tests consume config.
Two small clarity improvements you might consider:
final_shadow_models_path comment (Lines 74-81)
The entries already interpolate ${shadow_training.shadow_models_output_path}, so they are effectively full paths, not paths "relative to shadow_models_output_path" as the comment suggests. Rephrasing the comment (or removing "relative") would avoid confusion.
Cluster-specific absolute paths (Lines 13, 27)
The /projects/midst-experiments/... paths are clearly tailored to your cluster. It may help future users to add a brief note in the header indicating that these should be customized to their environment before running the example.
examples/ensemble_attack/run_metaclassifier_training.py (4)
67-76: Tighten validation of loaded target synthetic data
The existence check on target_model_synthetic_path is good, but:
- pd.read_csv will never return None, so the assert target_synthetic is not None is effectively redundant.
- If you want a stronger guard, consider validating that required columns are present and/or the DataFrame is non-empty instead.
For example:
-    target_synthetic = pd.read_csv(target_model_synthetic_path)
-
-    assert target_synthetic is not None, "Target model's synthetic data is missing."
+    target_synthetic = pd.read_csv(target_model_synthetic_path)
+    required_cols = df_meta_train.drop(columns=["trans_id", "account_id"]).columns
+    missing = set(required_cols) - set(target_synthetic.columns)
+    assert not missing, f"Target synthetic data missing expected columns: {missing}"
90-91: Be explicit about the account_id expectation or drop it defensively
You assert the presence of "trans_id" but not "account_id", yet you unconditionally drop both:
df_meta_train = df_meta_train.drop(columns=["trans_id", "account_id"])
df_meta_test = df_meta_test.drop(columns=["trans_id", "account_id"])
If some datasets omit account_id, this will raise a KeyError. Either assert its presence as well, or drop it with errors="ignore".
-    df_meta_train = df_meta_train.drop(columns=["trans_id", "account_id"])
-    df_meta_test = df_meta_test.drop(columns=["trans_id", "account_id"])
+    df_meta_train = df_meta_train.drop(columns=["trans_id", "account_id"], errors="ignore")
+    df_meta_test = df_meta_test.drop(columns=["trans_id", "account_id"], errors="ignore")
118-122: Add a sanity check before pickling the trained meta-classifier
It's possible (e.g., if tuning fails) that blending_attacker.trained_model ends up None, in which case you'd silently pickle None and only fail much later in testing. A small guard here would fail fast:
-    with open(model_path, "wb") as f:
-        pickle.dump(blending_attacker.trained_model, f)
+    assert blending_attacker.trained_model is not None, "Meta-classifier training did not produce a model."
+    with open(model_path, "wb") as f:
+        pickle.dump(blending_attacker.trained_model, f)
126-131: Clarify evaluation comment to reflect the new dedicated testing script
The inline comment:
df_original_synthetic=target_synthetic,  # For evaluation only, replace with actual target model during testing.
is now slightly misleading, since the actual target-model testing is handled by examples/ensemble_attack/test_attack_model.py. Consider rephrasing to make it clear this call is only for training-time evaluation and that the separate testing script handles real targets.
examples/ensemble_attack/test_attack_model.py (3)
54-68: Harden challenge/synthetic data loading and column handling
A few small robustness points here:
- As with the training script, it can be helpful to assert that the CSV files exist and have expected columns before reading/using them, to fail fast with clearer messages.
- You assert "trans_id" but not "account_id", yet you drop both. If some challenge datasets omit account_id, drop will raise a KeyError.
Suggested tweak:
-    df_test = pd.read_csv(challenge_data_path)
-    y_test = pd.read_csv(challenge_label_path).to_numpy().squeeze()
+    assert challenge_data_path.exists(), f"Challenge data not found at {challenge_data_path}"
+    assert challenge_label_path.exists(), f"Challenge labels not found at {challenge_label_path}"
+    df_test = pd.read_csv(challenge_data_path)
+    y_test = pd.read_csv(challenge_label_path).to_numpy().squeeze()
@@
-    df_test = df_test.drop(columns=["trans_id", "account_id"])
+    df_test = df_test.drop(columns=["trans_id", "account_id"], errors="ignore")
You might also want to assert that y_test.ndim == 1 after squeeze() if downstream code assumes a 1D array.
61-63: Mirror existence checks for target synthetic data
Unlike in run_metaclassifier_training, there's no existence check for target_synthetic_path here. For symmetry and clearer errors:
-    target_synthetic_path = Path(config.target_model.target_synthetic_data_path)
-    target_synthetic = pd.read_csv(target_synthetic_path)
+    target_synthetic_path = Path(config.target_model.target_synthetic_data_path)
+    assert target_synthetic_path.exists(), f"Target synthetic data not found at {target_synthetic_path}"
+    target_synthetic = pd.read_csv(target_synthetic_path)
120-127: Output path and naming are fine; consider differentiating test vs val
Saving to config.target_model.attack_probabilities_result_path with the filename pattern *_val_pred_proba.npy works, but the _val_ in the name may be slightly confusing for test-time attack probabilities. If you expect both validation and test artifacts to coexist often, consider renaming this to _test_pred_proba.npy to prevent ambiguity.
src/midst_toolkit/attacks/ensemble/blending.py (2)
31-60: Config-driven data_types_file_path is a good decoupling; consider better error messaging
Switching the constructor to take data_types_file_path: Path and loading self.column_types from JSON cleanly separates schema information from runtime data, which fits the rest of the refactor.
You may want to add a small guard to provide clearer diagnostics if the file is missing or malformed, e.g.:
-        with open(data_types_file_path, "r") as f:
-            self.column_types = json.load(f)
+        try:
+            with open(data_types_file_path, "r") as f:
+                self.column_types = json.load(f)
+        except FileNotFoundError as e:
+            raise FileNotFoundError(f"Data types file not found at {data_types_file_path}") from e
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in data types file at {data_types_file_path}") from e
202-209: Prediction precondition assertion is useful; minor type/typo nits
The new assertion in predict:
assert self.trained_model is not None, (
    "You must call .fit() before .predict() or provide a trained_model, "
    "or assign the trained model to the BlengingPlusPlus object."
)
nicely documents how the class can be used both with internal training and with an externally loaded model (as in the new testing script).
Two small polish points, if you touch this area again:
- There's a typo in the message (BlengingPlusPlus → BlendingPlusPlus).
- The signature types y_test: np.ndarray, but the docstring calls it optional and the code checks if y_test is not None:. For full consistency, you could update the type hint to np.ndarray | None.
Also applies to: 234-237
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (20)
- .gitignore (2 hunks)
- examples/ensemble_attack/configs/experiment_config.yaml (1 hunks)
- examples/ensemble_attack/configs/original_attack_config.yaml (5 hunks)
- examples/ensemble_attack/data_configs/trans.json (1 hunks)
- examples/ensemble_attack/real_data_collection.py (4 hunks)
- examples/ensemble_attack/run.sh (0 hunks)
- examples/ensemble_attack/run_attack.py (5 hunks)
- examples/ensemble_attack/run_metaclassifier_training.py (5 hunks)
- examples/ensemble_attack/run_shadow_model_training.py (2 hunks)
- examples/ensemble_attack/run_test.sh (1 hunks)
- examples/ensemble_attack/run_train.sh (1 hunks)
- examples/ensemble_attack/test_attack_model.py (1 hunks)
- src/midst_toolkit/attacks/ensemble/blending.py (7 hunks)
- src/midst_toolkit/attacks/ensemble/metric_utils.py (1 hunks)
- src/midst_toolkit/attacks/ensemble/process_split_data.py (1 hunks)
- src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (7 hunks)
- src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1 hunks)
- src/midst_toolkit/attacks/ensemble/xgboost_tuner.py (1 hunks)
- tests/unit/attacks/ensemble/test_meta_classifier.py (10 hunks)
- tests/unit/attacks/ensemble/test_rmia.py (4 hunks)
💤 Files with no reviewable changes (1)
- examples/ensemble_attack/run.sh
🧰 Additional context used
🧬 Code graph analysis (11)
src/midst_toolkit/attacks/ensemble/xgboost_tuner.py (1)
src/midst_toolkit/attacks/ensemble/metric_utils.py (1)
get_tpr_at_fpr(7-28)
examples/ensemble_attack/run_metaclassifier_training.py (1)
src/midst_toolkit/attacks/ensemble/blending.py (1)
predict(202-256)
src/midst_toolkit/attacks/ensemble/metric_utils.py (2)
src/midst_toolkit/evaluation/privacy/mia_scoring.py (2)
TprAtFpr (216-267)
TprFpr (159-213)
tests/unit/evaluation/privacy/test_mia_metrics.py (1)
test_tpr_at_fpr_function_bad_ranges(24-37)
examples/ensemble_attack/test_attack_model.py (3)
examples/ensemble_attack/run_shadow_model_training.py (1)
run_shadow_model_training (82-136)
src/midst_toolkit/attacks/ensemble/blending.py (3)
BlendingPlusPlus (26-256)
MetaClassifierType (21-23)
predict (202-256)
src/midst_toolkit/attacks/ensemble/data_utils.py (1)
load_dataframe(31-52)
examples/ensemble_attack/configs/original_attack_config.yaml (2)
tests/integration/attacks/ensemble/test_shadow_model_training.py (2)
test_train_and_fine_tune_tabddpm (135-187)
test_train_shadow_on_half_challenge_data (89-131)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (2)
fine_tune_tabddpm_and_synthesize(158-248)save_additional_tabddpm_config(36-76)
examples/ensemble_attack/data_configs/trans.json (2)
src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py (2)
fine_tune_model (47-136)
child_fine_tuning (243-339)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
test_train_and_fine_tune_tabddpm(135-187)
src/midst_toolkit/attacks/ensemble/process_split_data.py (1)
tests/unit/attacks/ensemble/test_process_data_split.py (1)
test_process_split_data(17-67)
examples/ensemble_attack/run_attack.py (4)
src/midst_toolkit/common/random.py (1)
set_all_random_seeds (11-55)
examples/ensemble_attack/run_shadow_model_training.py (2)
run_target_model_training (18-79)
run_shadow_model_training (82-136)
examples/ensemble_attack/run_metaclassifier_training.py (1)
run_metaclassifier_training (14-146)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
cfg(31-33)
tests/unit/attacks/ensemble/test_rmia.py (1)
src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (2)
Key (18-20)
get_rmia_gower (23-86)
examples/ensemble_attack/configs/experiment_config.yaml (3)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
test_save_additional_tabddpm_config (19-54)
tests/integration/attacks/ensemble/test_shadow_model_training.py (2)
cfg (31-33)
test_train_and_fine_tune_tabddpm (135-187)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
src/midst_toolkit/attacks/ensemble/blending.py (1)
src/midst_toolkit/attacks/ensemble/metric_utils.py (1)
get_tpr_at_fpr(7-28)
🪛 Ruff (0.14.5)
examples/ensemble_attack/test_attack_model.py
48-48: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
89-89: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: build
- GitHub Check: run-code-check
- GitHub Check: integration-tests
- GitHub Check: unit-tests
🔇 Additional comments (27)
examples/ensemble_attack/data_configs/trans.json (2)
9-12: The concern about data_split_ratios removal is unfounded.
The config classes define data_split_ratios with default values [0.7, 0.2, 0.1], and Pydantic will automatically use these defaults if the field is omitted from the JSON configuration. The code will not raise AttributeError or KeyError at runtime; it will simply use the defaults. All usages in train.py and clavaddpm_fine_tuning.py safely access the field through config objects that have this field properly defined with fallback defaults.
Likely an incorrect or invalid review comment.
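For readers unfamiliar with this behavior, here is a minimal, self-contained sketch of the Pydantic default mechanism the comment relies on. The class and field below are illustrative stand-ins, not the toolkit's actual config classes:

```python
from pydantic import BaseModel


class DataSplitConfig(BaseModel):
    # Hypothetical stand-in for the toolkit's config class: the field declares a
    # default, so a JSON payload that omits "data_split_ratios" falls back to it.
    data_split_ratios: list[float] = [0.7, 0.2, 0.1]


# Pydantic v2 API; validating an empty payload uses the declared default,
# so no AttributeError/KeyError occurs downstream.
config = DataSplitConfig.model_validate({})
print(config.data_split_ratios)  # [0.7, 0.2, 0.1]
```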
31-31: Verification confirms the scheduler implementation is correct.
The "cosine" scheduler is fully supported:
- SchedulerType.COSINE = "cosine" is defined in the enum
- Pydantic's BaseModel automatically converts the JSON string "cosine" to the SchedulerType.COSINE enum value
- The scheduler parameter is properly propagated from DiffusionConfig.scheduler to GaussianMultinomialDiffusion.__init__(scheduler_type: SchedulerType)
- The get_named_beta_schedule() function correctly handles SchedulerType.COSINE
No changes required.
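A minimal sketch of the enum coercion described above, using simplified stand-ins for SchedulerType and DiffusionConfig (the real classes have more members and fields):

```python
from enum import Enum

from pydantic import BaseModel


class SchedulerType(str, Enum):
    COSINE = "cosine"


class DiffusionConfig(BaseModel):
    # Simplified stand-in: Pydantic coerces the raw JSON string "cosine" into
    # the SchedulerType.COSINE enum member during validation.
    scheduler: SchedulerType


config = DiffusionConfig.model_validate({"scheduler": "cosine"})
assert config.scheduler is SchedulerType.COSINE
```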
examples/ensemble_attack/real_data_collection.py (2)
24-29: Good addition of experiment-specific attack types.
The new enum entries support experimentation with different training data sizes and follow the existing naming conventions consistently.
141-176: Well-designed parameterization of data splits.
The addition of optional population_splits and challenge_splits parameters with sensible defaults improves flexibility while maintaining backward compatibility. The default handling logic is clean and the automatic directory creation is a helpful addition. A sketch of this pattern is shown below.
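A rough sketch of that pattern, purely for illustration: the function name matches the one discussed in this review, but the signature, default split names, and body are assumptions rather than the real implementation.

```python
from pathlib import Path

import pandas as pd


def collect_population_data_ensemble(
    output_path: Path,
    population_splits: list[str] | None = None,
    challenge_splits: list[str] | None = None,
) -> None:
    # Optional split lists fall back to defaults, so existing callers keep
    # their old behavior (default names here are assumptions).
    if population_splits is None:
        population_splits = ["train", "dev"]
    if challenge_splits is None:
        challenge_splits = ["train", "dev", "final"]

    # Defensive directory creation before any files are written.
    output_path.mkdir(parents=True, exist_ok=True)

    for split in population_splits + challenge_splits:
        # Placeholder for the real collection logic.
        pd.DataFrame().to_csv(output_path / f"{split}.csv", index=False)
```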
src/midst_toolkit/attacks/ensemble/process_split_data.py (1)
170-171: LGTM!
Adding directory creation with parents=True and exist_ok=True is a defensive best practice that ensures the output path exists before writing data files.
src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (1)
23-86: LGTM! Cleaner API with DataFrame lists.
The refactoring from dict-based to list-based model data simplifies the API and improves clarity. Direct iteration over DataFrames is more intuitive than key-based access.
src/midst_toolkit/attacks/ensemble/metric_utils.py (1)
17-20: LGTM! Clearer documentation.
The updated docstring better explains that predictions are confidence values in [0,1] representing membership probability, improving clarity for users of this function.
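To make the metric concrete, here is a small, independent illustration of computing TPR at a fixed FPR budget from such confidence scores. This is not the toolkit's get_tpr_at_fpr implementation, just a sketch built on scikit-learn:

```python
import numpy as np
from sklearn.metrics import roc_curve


def tpr_at_fpr(labels: np.ndarray, predictions: np.ndarray, max_fpr: float = 0.1) -> float:
    """Largest TPR achievable while keeping FPR at or below max_fpr."""
    fpr, tpr, _ = roc_curve(labels, predictions)
    return float(tpr[fpr <= max_fpr].max())


# Membership labels and attack confidences in [0, 1].
labels = np.array([1, 0, 1, 1, 0, 0, 1, 0])
scores = np.array([0.9, 0.2, 0.7, 0.6, 0.4, 0.1, 0.8, 0.3])
print(tpr_at_fpr(labels, scores, max_fpr=0.25))
```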
src/midst_toolkit/attacks/ensemble/xgboost_tuner.py (1)
12-12: LGTM! Import path refactored to metric_utils.
Moving get_tpr_at_fpr to metric_utils is a more appropriate module organization for metric-related utilities, and aligns with similar changes across the ensemble attack module.
examples/ensemble_attack/run_attack.py (2)
16-16: Deterministic seeding hook looks good
Using set_all_random_seeds(seed=config.random_seed) at the start of main is a solid addition and aligns with the new random_seed field in experiment_config.yaml. This will help keep runs reproducible across all three pipeline stages.
Also applies to: 58-58
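For reference, a helper like this typically seeds every relevant source of randomness. The sketch below is an assumption about what such a function does, not the toolkit's actual implementation in src/midst_toolkit/common/random.py:

```python
import os
import random

import numpy as np
import torch


def set_all_random_seeds(seed: int = 42) -> None:
    # Seed Python, NumPy, and PyTorch (CPU and all CUDA devices) so repeated
    # runs of the pipeline produce the same shadow models and meta-features.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
```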
28-37: Passing population_splits/challenge_splits through is aligned with the new config
Wiring population_splits and challenge_splits from config.data_processing_config into collect_population_data_ensemble matches the new experiment_config.yaml structure and keeps the example flexible for different split configurations.
tests/unit/attacks/ensemble/test_rmia.py (1)
155-161: Tests now correctly exerciseget_rmia_gowerwith a list of synthetic DataFramesThe updated tests that build:
shadow_synthetic_listfrombase_data["model_data"][Key.TRAINED_RESULTS.value], andsynthetic_data_listfrom the same source (and fromKey.FINE_TUNED_RESULTS.valuefor the missing categorical case),and then pass these lists to
get_rmia_gower(model_data=...)are well aligned with the new API (model_data: list[pd.DataFrame]).The assertions still verify:
- correct call counts,
- that the ID column is dropped before distance computation, and
- that sampling uses
random_state=base_data["random_seed"].This gives good coverage of the refactored implementation.
Also applies to: 198-206, 226-232
tests/unit/attacks/ensemble/test_meta_classifier.py (6)
31-42: Mock config now matches new metaclassifier settingsAdding
data_types_file_pathandmeta_classifier_model_nameundermetaclassifierinmock_config_with_json_pathkeeps the test config in sync with the realexperiment_config.yamlshape and ensures BlendingPlusPlus can be initialized without missing fields.
47-86: Including the ID column insample_dataframesis consistent withdata_types.jsonexpectationsExtending
sample_dataframesso that all frames includeid_col(matchingMOCK_COLUMN_TYPES_CONTENT["id_column_name"]) is the right move. It mirrors how the real pipeline uses an explicit ID column for RMIA/DOMIAS/Gower feature calculation and for metaclassifier training/evaluation.
89-122: BlendingPlusPlus initialization tests correctly exercisedata_types_file_pathhandlingThe updated
test_init_successandtest_init_invalid_type_raises_error:
- Pass
data_types_file_path=mock_config_with_json_path.metaclassifier.data_types_file_path,- Verify that
openis called once with that path in read mode, and- Confirm that the loaded
column_typesandmeta_classifier_typeare as expected.This matches the new constructor contract for BlendingPlusPlus and provides good regression coverage around the JSON schema.
141-231: Meta‑feature preparation tests are aligned with the new column‑types and RMIA wiringIn both
_prepare_meta_featurestests, passingdata_types_file_pathinto BlendingPlusPlus and then:
- using
MOCK_COLUMN_TYPES_CONTENT["categorical"],["numerical"], and"id_column_name", and- asserting that
calculate_rmia_signalsreceivesdf_input,shadow_data_collection,categorical_column_names,id_column_name, andid_column_datanicely validate that the new column‑types–driven code path is exercised correctly.
235-302: Fit‑path tests for LR/XGB correctly incorporatedata_types_file_pathBoth
test_fit_logistic_regressionandtest_fit_xgboostnow initialize BlendingPlusPlus withdata_types_file_pathand confirm:
_prepare_meta_featuresis invoked once, and- the right model type is constructed and trained, with hyperparameters pulled from
mock_config_with_json_path.This keeps the tests in sync with the updated API without changing their behavioral intent.
303-364: Predict‑path tests correctly use the new initialization signatureThe predict‑related tests now pass
data_types_file_pathinto BlendingPlusPlus in both the “not fit yet” and full predict flow cases, while still asserting:
- an
AssertionErrorwhenpredictis called beforefit, and- correct probability extraction and TPR@FPR computation in the happy path.
These updates maintain strong coverage of the predict API after the constructor change.
examples/ensemble_attack/run_metaclassifier_training.py (4)
6-6: Pandas import aligns with new CSV-based synthetic loading
The added pandas import is appropriate for reading the target model's synthetic data from CSV; no issues here.
14-28: Decoupling from TrainingResult via target_model_synthetic_path looks good
The new target_model_synthetic_path: Path argument and updated docstring correctly narrow the training dependency to just the target's synthetic data, which matches the PR objective of removing the need for the full TrainingResult.
96-103: BlendingPlusPlus initialization with data_types_file_path is consistent
Passing data_types_file_path=Path(config.metaclassifier.data_types_file_path) matches the updated BlendingPlusPlus constructor and centralizes column-type config in the JSON file; this wiring looks correct.
136-143: Evaluation output path and naming are coherent
Creating attack_evaluation_result_path and saving to a deterministic filename based on model_type (*_val_pred_proba.npy) is a nice improvement over timestamped paths and should simplify downstream consumption.
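A tiny illustration of that deterministic naming scheme; the paths and variable names here are placeholders, not the script's actual values:

```python
from pathlib import Path

import numpy as np

model_type = "xgb"  # placeholder for config.metaclassifier.model_type
result_dir = Path("results/attack_evaluation")
result_dir.mkdir(parents=True, exist_ok=True)

pred_proba = np.array([0.91, 0.12, 0.77])  # placeholder attack probabilities
np.save(result_dir / f"{model_type}_val_pred_proba.npy", pred_proba)
```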
examples/ensemble_attack/test_attack_model.py (4)
18-35: Hydra entrypoint and logging of target model ID look good
Using @hydra.main with experiment_config and logging config.target_model.target_model_id on start provides a clear entrypoint and traceability for which target model is under attack.
95-118: BlendingPlusPlus construction and predict usage are consistent with the API
- Passing data_types_file_path=Path(config.metaclassifier.data_types_file_path) matches the updated constructor.
- Assigning blending_attacker.trained_model = trained_mataclassifier_model aligns with the new assertion in predict() ("provide a trained_model, or assign the trained model to the BlengingPlusPlus object").
- The predict call wires df_test, df_original_synthetic=target_synthetic, df_reference, and id_column_data=test_trans_ids coherently with the training-time usage.
This section looks correct; a sketch of this load-and-predict flow is shown below.
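A minimal sketch of that load-and-predict flow, assuming the constructor keyword, enum member, and predict arguments named in this review. The file paths, the meta_classifier_type parameter name, and the XGBOOST member are illustrative assumptions, not the script's exact code:

```python
import pickle
from pathlib import Path

import pandas as pd

from midst_toolkit.attacks.ensemble.blending import BlendingPlusPlus, MetaClassifierType

# Placeholder paths; the real values come from the Hydra config.
model_path = Path("results/meta_classifier.pkl")
df_test = pd.read_csv("challenge_data.csv")
target_synthetic = pd.read_csv("target_synthetic.csv")
df_reference = pd.read_csv("reference_population.csv")
test_trans_ids = df_test["trans_id"]

blending_attacker = BlendingPlusPlus(
    data_types_file_path=Path("examples/ensemble_attack/data_configs/data_types.json"),  # placeholder
    meta_classifier_type=MetaClassifierType.XGBOOST,  # assumed parameter and member names
)

# Attach an externally trained model instead of calling .fit(), which the new
# predict() assertion explicitly allows.
with open(model_path, "rb") as f:
    blending_attacker.trained_model = pickle.load(f)

probabilities = blending_attacker.predict(
    df_test,
    df_original_synthetic=target_synthetic,
    df_reference=df_reference,
    id_column_data=test_trans_ids,
)
```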
133-134: __main__ guard correctly delegates to Hydra-decorated entrypoint
Calling run_metaclassifier_testing() under the if __name__ == "__main__": guard is the standard Hydra pattern and should work as intended.
77-90: Pickle concern does not apply to this closed research pipeline
The review correctly identifies that this code is "fine in a closed research pipeline," and that is precisely the context here. The pickle files are generated internally by run_shadow_model_training(config) during execution, not loaded from untrusted external sources. Config paths derive from developer-controlled YAML files in examples/ensemble_attack/configs/. While Hydra allows CLI overrides, this is standard for research code where the researcher controls the command line and config values. The assumption that shadow model files are trusted is valid in this context.
Likely an incorrect or invalid review comment.
src/midst_toolkit/attacks/ensemble/blending.py (2)
6-7: Importing Path matches new data_types_file_path usage
Adding from pathlib import Path is consistent with the constructor's new data_types_file_path: Path parameter.
82-86: RMIA integration uses target synthetic data consistently
The _prepare_meta_features docstring and the calculate_rmia_signals call now both refer to synthetic data as df_synthetic/target_synthetic_data, which matches how the callers (fit and predict) pass in the diffusion model's synthetic data. This should keep RMIA features aligned with the rest of the meta-features.
Also applies to: 100-106
src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (outdated comment, resolved)
sarakodeiri
left a comment
Overall looks great to me. Added very minor comments / questions.
* Added testing several targets on multiple GPUs
* Added a comment
lotif
left a comment
LGTM apart from one small thing. Thanks for addressing the comments!
df_test = df_test.drop(columns=["trans_id", "account_id"])
with open(Path(config.metaclassifier.data_types_file_path), "r") as f:
    column_types = json.load(f)
id_column_name = column_types["id_column_name"]
Nice! Thanks for finding this!
PR Type
[Feature | Fix ]
Short Description
Clickup Ticket(s): Link
Added a script in the Ensemble attack example to facilitate testing on target models in the experiment setup. This PR also removes the dependency of the meta classifier pipeline on the whole target's TrainingResult object, as only its
synthetic_data is actually required.
Tests Added
Changes are made to the existing tests.