@@ -164,17 +164,14 @@ def get_opt_param(params):
 
         def get_data_loader(_training_data, _validation_data, _training_params):
             def get_dataloader_and_buffer(_data, _params):
-                _sampler = get_sampler_from_params(_data, _params)
-                if _sampler is None:
-                    log.warning(
-                        "Sampler not specified!"
-                    )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
+                # _sampler = get_sampler_from_params(_data, _params)
+                # if _sampler is None:
+                #     log.warning(
+                #         "Sampler not specified!"
+                #     )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
                 _dataloader = DataLoader(
                     _data,
-                    batch_sampler=paddle.io.BatchSampler(
-                        sampler=_sampler,
-                        drop_last=False,
-                    ),
+                    batch_size=1,
                     num_workers=NUM_WORKERS
                     if dist.is_available()
                     else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1
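
For context on the hunk above: the original loader drew batches through a weighted `BatchSampler`, while the debug version falls back to plain `batch_size=1`. A minimal, self-contained sketch of the two constructions, using a toy `paddle.io.Dataset` and a uniform `WeightedRandomSampler` as stand-ins (the real code builds its sampler from `get_sampler_from_params` over DeePMD data systems):

```python
# Sketch only: contrasts the two DataLoader setups swapped in the hunk above.
# ToyDataset and the uniform weights are illustrative stand-ins.
import numpy as np
from paddle.io import BatchSampler, DataLoader, Dataset, WeightedRandomSampler


class ToyDataset(Dataset):
    def __init__(self, n=8):
        self.x = np.arange(n, dtype="float32").reshape(n, 1)

    def __getitem__(self, idx):
        return self.x[idx]

    def __len__(self):
        return len(self.x)


data = ToyDataset()

# Original path: weighted sampler wrapped in a BatchSampler;
# replacement=True keeps the iterator from stopping prematurely.
weights = np.ones(len(data), dtype="float32")
sampler = WeightedRandomSampler(weights, num_samples=len(data), replacement=True)
loader_sampled = DataLoader(data, batch_sampler=BatchSampler(sampler=sampler, drop_last=False))

# Debug path used in this diff: no sampler, sequential batches of one.
loader_plain = DataLoader(data, batch_size=1, num_workers=0)

for batch in loader_plain:
    pass  # each iteration yields one sample per batch
```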
@@ -325,17 +322,18 @@ def get_lr(lr_params):
                 self.validation_data,
                 self.valid_numb_batch,
             ) = get_data_loader(training_data, validation_data, training_params)
-            training_data.print_summary(
-                "training",
-                to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights),
-            )
-            if validation_data is not None:
-                validation_data.print_summary(
-                    "validation",
-                    to_numpy_array(
-                        self.validation_dataloader.batch_sampler.sampler.weights
-                    ),
-                )
+            # no sampler, do not need print!
+            # training_data.print_summary(
+            #     "training",
+            #     to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights),
+            # )
+            # if validation_data is not None:
+            #     validation_data.print_summary(
+            #         "validation",
+            #         to_numpy_array(
+            #             self.validation_dataloader.batch_sampler.sampler.weights
+            #         ),
+            #     )
         else:
             (
                 self.training_dataloader,
@@ -370,27 +368,27 @@ def get_lr(lr_params):
                     validation_data[model_key],
                     training_params["data_dict"][model_key],
                 )
-
-                training_data[model_key].print_summary(
-                    f"training in {model_key}",
-                    to_numpy_array(
-                        self.training_dataloader[
-                            model_key
-                        ].batch_sampler.sampler.weights
-                    ),
-                )
-                if (
-                    validation_data is not None
-                    and validation_data[model_key] is not None
-                ):
-                    validation_data[model_key].print_summary(
-                        f"validation in {model_key}",
-                        to_numpy_array(
-                            self.validation_dataloader[
-                                model_key
-                            ].batch_sampler.sampler.weights
-                        ),
-                    )
+                # no sampler, do not need print!
+                # training_data[model_key].print_summary(
+                #     f"training in {model_key}",
+                #     to_numpy_array(
+                #         self.training_dataloader[
+                #             model_key
+                #         ].batch_sampler.sampler.weights
+                #     ),
+                # )
+                # if (
+                #     validation_data is not None
+                #     and validation_data[model_key] is not None
+                # ):
+                #     validation_data[model_key].print_summary(
+                #         f"validation in {model_key}",
+                #         to_numpy_array(
+                #             self.validation_dataloader[
+                #                 model_key
+                #             ].batch_sampler.sampler.weights
+                #         ),
+                #     )
 
         # Learning rate
         self.warmup_steps = training_params.get("warmup_steps", 0)
@@ -706,7 +704,7 @@ def run(self) -> None:
         fout1 = open(record_file, mode="w", buffering=1)
         log.info("Start to train %d steps.", self.num_steps)
         if dist.is_available() and dist.is_initialized():
-            log.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}")
+            log.info(f"xxx Rank: {dist.get_rank()}/{dist.get_world_size()}")
         if self.enable_tensorboard:
             from tensorboardX import (
                 SummaryWriter,
@@ -755,50 +753,54 @@ def step(_step_id, task_key="Default") -> None:
                     if self.world_size > 1
                     else contextlib.nullcontext
                 )
+
+                # with nvprof_context(enable_profiling, "Forward pass"):
+                log_dict = {}
+
+                input_dict = {
+                    "spin": None,
+                    "fparam": None,
+                    "aparam": None,
+                }
+                label_dict = {
+                    "find_box": 1.0,
+                    "find_coord": 1.0,
+                    "find_numb_copy": 0.0,
+                    "find_energy": 1.0,
+                    "find_force": 1.0,
+                    "find_virial": 0.0,
+                }
+                for k in ["atype", "box", "coord"]:
+                    input_dict[k] = paddle.load(f"./input_{k}.pd")
+                for k in ["energy", "force", "natoms", "numb_copy", "virial"]:
+                    label_dict[k] = paddle.load(f"./label_{k}.pd")
+
+                for __key in ('coord', 'atype', 'box'):
+                    input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
+                for __key, _ in label_dict.items():
+                    if isinstance(label_dict[__key], paddle.Tensor):
+                        label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
 
-                # with sync_context():
-                #     with nvprof_context(enable_profiling, "Forward pass"):
-                #         model_pred, loss, more_loss = self.wrapper(
-                #             **input_dict,
-                #             cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                #             label=label_dict,
-                #             task_key=task_key,
-                #         )
-
-                #     with nvprof_context(enable_profiling, "Backward pass"):
-                #         loss.backward()
-
-                #     if self.world_size > 1:
-                #         # fuse + allreduce manually before optimization if use DDP + no_sync
-                #         # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
-                #         hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
-
-                with nvprof_context(enable_profiling, "Forward pass"):
-                    for __key in ('coord', 'atype', 'box'):
-                        input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
-                    for __key, _ in label_dict.items():
-                        if isinstance(label_dict[__key], paddle.Tensor):
-                            label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
-                    model_pred, loss, more_loss = self.wrapper(
-                        **input_dict,
-                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                        label=label_dict,
-                        task_key=task_key,
-                    )
+                model_pred, loss, more_loss = self.wrapper(
+                    **input_dict,
+                    cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+                    label=label_dict,
+                    task_key=task_key,
+                )
 
-                with nvprof_context(enable_profiling, "Backward pass"):
-                    loss.backward()
+                # with nvprof_context(enable_profiling, "Backward pass"):
+                loss.backward()
 
                 if self.gradient_max_norm > 0.0:
-                    with nvprof_context(enable_profiling, "Gradient clip"):
-                        paddle.nn.utils.clip_grad_norm_(
-                            self.wrapper.parameters(),
-                            self.gradient_max_norm,
-                            error_if_nonfinite=True,
-                        )
+                    # with nvprof_context(enable_profiling, "Gradient clip"):
+                    paddle.nn.utils.clip_grad_norm_(
+                        self.wrapper.parameters(),
+                        self.gradient_max_norm,
+                        error_if_nonfinite=True,
+                    )
 
-                with nvprof_context(enable_profiling, "Adam update"):
-                    self.optimizer.step()
+                # with nvprof_context(enable_profiling, "Adam update"):
+                self.optimizer.step()
                 self.scheduler.step()
 
             else:
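
The replacement forward path above reloads fixed batches from disk with `paddle.load` and marks them as sharded along the batch dimension through Paddle's auto-parallel API before calling the wrapper. A standalone sketch of that dump/reload/shard pattern, with an explicitly constructed 2-rank mesh and illustrative tensor shapes (the trainer itself obtains its mesh from `dist.get_mesh()`):

```python
# Sketch of the dump / reload / shard-along-batch pattern used in the hunk above.
# Assumes two data-parallel ranks; launch with e.g.
#   python -m paddle.distributed.launch --gpus 0,1 shard_demo.py
import paddle
import paddle.distributed as dist

# 1-D process mesh over two ranks.
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

# Dump a batch once, e.g. captured from a real dataloader (shape is illustrative) ...
coord = paddle.randn([2, 192, 3])  # [batch, natoms, 3]
paddle.save(coord, "./input_coord.pd")

# ... then reload it and shard dim 0 (the batch axis) across the mesh,
# mirroring the input_dict/label_dict handling in the diff.
coord = paddle.load("./input_coord.pd")
coord = dist.shard_tensor(coord, mesh=mesh, placements=[dist.Shard(0)])
print(coord.shape)  # global (logical) shape
```

Only real tensors can be distributed, which is why the diff guards the label loop with `isinstance(label_dict[__key], paddle.Tensor)` and leaves the scalar `find_*` flags untouched.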
@@ -856,7 +858,9 @@ def log_loss_valid(_task_key="Default"):
 
                 if not self.multi_task:
                     train_results = log_loss_train(loss, more_loss)
-                    valid_results = log_loss_valid()
+                    # valid_results = log_loss_valid()
+                    # no run valid!
+                    valid_results = None
                     if self.rank == 0:
                         log.info(
                             format_training_message_per_task(
@@ -938,39 +942,39 @@ def log_loss_valid(_task_key="Default"):
                 ):
                     self.total_train_time += train_time
 
-                if fout:
-                    if self.lcurve_should_print_header:
-                        self.print_header(fout, train_results, valid_results)
-                        self.lcurve_should_print_header = False
-                    self.print_on_training(
-                        fout, display_step_id, cur_lr, train_results, valid_results
-                    )
-
-            if (
-                ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
-                or (_step_id + 1) == self.num_steps
-            ) and (self.rank == 0 or dist.get_rank() == 0):
-                # Handle the case if rank 0 aborted and re-assigned
-                self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pd")
-                self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
-                log.info(f"Saved model to {self.latest_model}")
-                symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
-                with open("checkpoint", "w") as f:
-                    f.write(str(self.latest_model))
+                # if fout:
+                #     if self.lcurve_should_print_header:
+                #         self.print_header(fout, train_results, valid_results)
+                #         self.lcurve_should_print_header = False
+                #     self.print_on_training(
+                #         fout, display_step_id, cur_lr, train_results, valid_results
+                #     )
+
+            # if (
+            #     ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
+            #     or (_step_id + 1) == self.num_steps
+            # ) and (self.rank == 0 or dist.get_rank() == 0):
+            #     # Handle the case if rank 0 aborted and re-assigned
+            #     self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pd")
+            #     self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
+            #     log.info(f"Saved model to {self.latest_model}")
+            #     symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
+            #     with open("checkpoint", "w") as f:
+            #         f.write(str(self.latest_model))
 
             # tensorboard
-            if self.enable_tensorboard and (
-                display_step_id % self.tensorboard_freq == 0 or display_step_id == 1
-            ):
-                writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id)
-                writer.add_scalar(f"{task_key}/loss", loss.item(), display_step_id)
-                for item in more_loss:
-                    writer.add_scalar(
-                        f"{task_key}/{item}", more_loss[item].item(), display_step_id
-                    )
-
-            if enable_profiling:
-                core.nvprof_nvtx_pop()
+            # if self.enable_tensorboard and (
+            #     display_step_id % self.tensorboard_freq == 0 or display_step_id == 1
+            # ):
+            #     writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id)
+            #     writer.add_scalar(f"{task_key}/loss", loss.item(), display_step_id)
+            #     for item in more_loss:
+            #         writer.add_scalar(
+            #             f"{task_key}/{item}", more_loss[item].item(), display_step_id
+            #         )
+
+            # if enable_profiling:
+            #     core.nvprof_nvtx_pop()
 
         self.wrapper.train()
         self.t0 = time.time()