File tree Expand file tree Collapse file tree 1 file changed +2
-6
lines changed
Expand file tree Collapse file tree 1 file changed +2
-6
lines changed Original file line number Diff line number Diff line change @@ -396,11 +396,9 @@ def train_llm(self):
396396 # LoRA training on y_llm_ttrain --> Testing on y_llm_ttest
397397 x_llm_ttrain_b , scores_ttrain_b = (
398398 get_batches (self .x_llm_ttrain , batch_size = self .batch_size , dtype = int ),
399- #get_batches(self.attn_llm_ttrain, batch_size=self.batch_size, dtype=int),
400399 get_batches (self .y_ttrain , batch_size = self .batch_size , dtype = float )
401400 )
402401
403- #x_llm_ttest_b = get_batches(self.x_llm_ttest, batch_size=self.batch_size, dtype=int)
404402 if self .llm_key == 'prosst' :
405403 y_llm_ttest = self .llm_inference_function (
406404 xs = self .x_llm_ttest ,
@@ -457,8 +455,7 @@ def train_llm(self):
457455 self .llm_attention_mask ,
458456 self .structure_input_ids ,
459457 n_epochs = 50 ,
460- device = self .device ,
461- #seed=self.seed
458+ device = self .device
462459 )
463460 y_llm_lora_ttrain = self .llm_inference_function (
464461 xs = self .x_llm_ttrain ,
@@ -486,8 +483,7 @@ def train_llm(self):
486483 self .llm_model ,
487484 self .llm_optimizer ,
488485 n_epochs = 5 ,
489- device = self .device ,
490- #seed=self.seed
486+ device = self .device
491487 )
492488 y_llm_lora_ttrain = self .llm_inference_function (
493489 xs = x_llm_ttrain_b ,
You can’t perform that action at this time.
0 commit comments