@@ -47,7 +47,7 @@ def mixed_loss(
4747 noise : list [list [float ]],
4848 parallel_batch : int ,
4949 additional_timestep : int ,
50- timestep : Tensor ,
50+ timestep : int ,
5151) -> Tensor :
5252 """
5353 Compute the loss function for the Tartan Federer classifier.
@@ -70,22 +70,31 @@ def mixed_loss(
7070 categorical_features = features [:, diffusion_model .num_numerical_features :]
7171
7272 noise_tensor = torch .tensor (noise , device = device , dtype = torch .float )
73+ # Here we're repeating the noise tensor for each sample in the dataset so that each point gets the same set of
74+ # different noise values. This happens because parallel_batch is set to num_noise_per_time_step in preceding
75+ # calling functions
7376 batch_noise = noise_tensor .repeat (batch_size , 1 )
7477
7578 # TODO: Handle the categorical features more effectively. Because the numerical features were originally ignored
7679 # in the diffusion model and thus are ignored in this attack construction.
7780 numerical_features = numerical_features .repeat_interleave (parallel_batch , dim = 0 )
7881 categorical_features = categorical_features .repeat_interleave (parallel_batch , dim = 0 )
7982
83+ # Note that the shape here is not equivalent to batch_size after the interleave
84+ zero_timestep = torch .zeros (numerical_features .shape [0 ], device = DEVICE ).long ()
85+ current_timestep = zero_timestep + timestep
86+
8087 # forward x_num_t with (t + additional_t) timesteps
8188 # TODO: Expand this to also include categorical features
8289 numerical_features_t = diffusion_model .gaussian_q_sample (
83- numerical_features , timestep + additional_timestep , noise = batch_noise
90+ numerical_features , current_timestep + additional_timestep , noise = batch_noise
8491 )
8592
8693 # predict noises with t timesteps
87- predicted_noise = diffusion_model ._denoise_fn (numerical_features_t , timestep , ** outputs )
88- current_loss = diffusion_model ._gaussian_loss (predicted_noise , batch_noise , batch_noise , timestep , batch_noise )
94+ predicted_noise = diffusion_model ._denoise_fn (numerical_features_t , current_timestep , ** outputs )
95+ current_loss = diffusion_model ._gaussian_loss (
96+ predicted_noise , batch_noise , batch_noise , current_timestep , batch_noise
97+ )
8998 return current_loss .reshape (- 1 , parallel_batch )
9099
91100
@@ -120,7 +129,7 @@ def make_dataset_from_df_with_loaded(
120129 )
121130
122131 numerical_features = {"train" : data [numerical_column_names ].values .astype (np .float32 )}
123- categorical_features = {"train" : data [categorical_column_names ].values . astype ( np .float32 )}
132+ categorical_features = {"train" : data [categorical_column_names ].to_numpy ( dtype = np .str_ )}
124133 targets = {"train" : data [[table_metadata .target_column_name ]].values .astype (np .float32 )}
125134
126135 if len (categorical_column_names ) > 0 :
@@ -307,16 +316,14 @@ def get_score(
307316
308317 with torch .no_grad ():
309318 # get loss here
310- current_timestep , _ = diffusion_model .sample_time (batch_size , DEVICE )
311-
312319 loss = mixed_loss (
313320 diffusion_model = diffusion_model ,
314321 features = features ,
315322 outputs = outputs ,
316323 noise = input_noise ,
317324 parallel_batch = parallel_batch ,
318325 additional_timestep = additional_timestep ,
319- timestep = current_timestep * 0 + timestep ,
326+ timestep = timestep ,
320327 )
321328
322329 # TODO: Should we be summing this loss or something? We're only going to get the last loss in the iteration.
@@ -347,7 +354,11 @@ def filter_dataframe(
347354
348355
349356def prepare_dataframe (
350- model_dir : Path , merged_data : pd .DataFrame , columns_for_deduplication : list [str ], samples_per_train_model : int
357+ model_dir : Path ,
358+ merged_data : pd .DataFrame ,
359+ columns_for_deduplication : list [str ],
360+ samples_per_train_model : int ,
361+ mia_dataset_name : str ,
351362) -> pd .DataFrame :
352363 """
353364 Prepare the dataframes for Tartan Federer Attack Classifier training.
@@ -358,6 +369,7 @@ def prepare_dataframe(
358369 merged_data: Dataframe constructed with the ``prepare_data_for_attack`` function.
359370 columns_for_deduplication: Columns to use in filtering the dataframes.
360371 samples_per_train_model: Number of samples to draw from the prepared data for model training.
372+ mia_dataset_name: Name of the MIA dataset file to be saved.
361373
362374 Returns:
363375 Filtered dataframe reading for classifier training (or testing)
@@ -370,7 +382,7 @@ def prepare_dataframe(
370382 data_from_train = raw_data .sample (samples_per_train_model )
371383
372384 df_data = pd .concat ([data_exclusive , data_from_train ], ignore_index = True )
373- df_data .to_csv (model_dir / "data_for_training_MIA.csv" , index = False )
385+ df_data .to_csv (model_dir / mia_dataset_name , index = False )
374386
375387 return filter_dataframe (merged_data , df_data , columns_for_deduplication )
376388
@@ -421,14 +433,14 @@ def train_tartan_federer_attack_classifier(
421433 df_train_merge , _ , _ = prepare_data_for_attack (
422434 model_indices = train_indices ,
423435 model_type = model_type ,
424- models_base_dir = Path ( "/projects/midst-experiments/tabddpm_midst_toolkit/train/" ) ,
436+ models_base_dir = model_data_dir ,
425437 columns_for_deduplication = columns_for_deduplication ,
426438 )
427439
428440 df_test_merge , _ , _ = prepare_data_for_attack (
429441 model_indices = val_indices ,
430442 model_type = model_type ,
431- models_base_dir = Path ( "/projects/aieng/midst_competition/data/tabddpm" ) ,
443+ models_base_dir = model_data_dir ,
432444 columns_for_deduplication = columns_for_deduplication ,
433445 )
434446
@@ -461,12 +473,20 @@ def train_tartan_federer_attack_classifier(
461473
462474 if model_number in train_indices :
463475 df_train_merge = prepare_dataframe (
464- model_dir , df_train_merge , columns_for_deduplication , samples_per_train_model
476+ model_dir ,
477+ df_train_merge ,
478+ columns_for_deduplication ,
479+ samples_per_train_model ,
480+ "data_for_training_MIA.csv" ,
465481 )
466482
467483 elif model_number in val_indices :
468484 df_test_merge = prepare_dataframe (
469- model_dir , df_test_merge , columns_for_deduplication , sample_per_val_model
485+ model_dir ,
486+ df_test_merge ,
487+ columns_for_deduplication ,
488+ sample_per_val_model ,
489+ "data_for_validating_MIA.csv" ,
470490 )
471491
472492 timestep_count = 0
0 commit comments