diff --git a/deepmd/dpmodel/utils/__init__.py b/deepmd/dpmodel/utils/__init__.py index 5941d2c8d0..cd6eb696c9 100644 --- a/deepmd/dpmodel/utils/__init__.py +++ b/deepmd/dpmodel/utils/__init__.py @@ -36,6 +36,11 @@ save_dp_model, traverse_model_dict, ) +from .training_utils import ( + compute_total_numb_batch, + resolve_model_prob, + resolve_model_prob_from_epochs, +) __all__ = [ "AtomExcludeMask", @@ -49,6 +54,7 @@ "aggregate", "build_multiple_neighbor_list", "build_neighbor_list", + "compute_total_numb_batch", "extend_coord_with_ghosts", "get_graph_index", "get_multiple_nlist_key", @@ -60,6 +66,8 @@ "nlist_distinguish_types", "normalize_coord", "phys2inter", + "resolve_model_prob", + "resolve_model_prob_from_epochs", "save_dp_model", "to_face_distance", "traverse_model_dict", diff --git a/deepmd/dpmodel/utils/training_utils.py b/deepmd/dpmodel/utils/training_utils.py new file mode 100644 index 0000000000..72dfda930a --- /dev/null +++ b/deepmd/dpmodel/utils/training_utils.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from collections.abc import ( + Iterable, +) + +import numpy as np + +log = logging.getLogger(__name__) + + +def compute_total_numb_batch( + numb_batches: Iterable[int], + sampler_weights: np.ndarray, +) -> int: + """Compute total number of batches considering sampler weights. + + Parameters + ---------- + numb_batches : Iterable[int] + Number of batches for each data system. + sampler_weights : np.ndarray + Sampling weights for each data system. + + Returns + ------- + int + Total number of batches. + + Raises + ------ + ValueError + If input validation fails. + """ + weights = np.asarray(sampler_weights, dtype=np.float64) + if weights.ndim != 1: + raise ValueError("Sampler weights must be 1D.") + if weights.size == 0: + raise ValueError("Sampler weights are empty.") + if not np.all(np.isfinite(weights)): + raise ValueError("Sampler weights must be finite.") + if np.any(weights < 0.0): + raise ValueError("Sampler weights must be non-negative.") + weight_sum = float(np.sum(weights)) + if weight_sum <= 0.0: + raise ValueError("Sampler weights must sum to a positive value.") + probs = weights / weight_sum + nbatches = np.asarray(numb_batches, dtype=np.float64) + if nbatches.ndim != 1: + raise ValueError("Number of batches must be 1D.") + if nbatches.size == 0: + raise ValueError("Number of batches is empty.") + if not np.all(np.isfinite(nbatches)): + raise ValueError("Number of batches must be finite.") + if np.any(nbatches < 0.0): + raise ValueError("Number of batches must be non-negative.") + if nbatches.shape[0] != probs.shape[0]: + raise ValueError("Number of batches and sampler weights must match.") + valid = probs > 0.0 + if not np.any(valid): + raise ValueError( + "Sampler probabilities must contain at least one positive entry." + ) + return int(np.ceil(np.max(nbatches[valid] / probs[valid]))) + + +def resolve_model_prob( + model_keys: list[str], + model_prob_config: dict[str, float] | None, + model_training_data: dict[str, object], + rank: int = 0, +) -> np.ndarray: + """Resolve model training probability for multi-task training. + + Parameters + ---------- + model_keys : list[str] + List of model keys. + model_prob_config : dict[str, float] | None + User-specified model probabilities. If None, use data size. + model_training_data : dict[str, object] + Training data for each model. + rank : int, optional + Process rank for distributed training, by default 0. + + Returns + ------- + np.ndarray + Normalized model probabilities. + + Raises + ------ + ValueError + If input validation fails. + """ + model_prob = np.zeros(len(model_keys), dtype=np.float64) + if model_prob_config: + missing = [k for k in model_keys if k not in model_prob_config] + if missing: + raise ValueError( + f"training.model_prob must specify all tasks; missing: {missing}" + ) + for ii, model_key in enumerate(model_keys): + if model_key in model_prob_config: + model_prob[ii] = float(model_prob_config[model_key]) + else: + if rank == 0: + log.info( + "training.model_prob is not set or empty; defaulting to the " + "number of systems per task." + ) + for ii, model_key in enumerate(model_keys): + model_prob[ii] = float(len(model_training_data[model_key])) + if not np.all(np.isfinite(model_prob)): + raise ValueError("Model prob must be finite.") + if np.any(model_prob < 0.0): + raise ValueError("Model prob must be non-negative.") + sum_prob = float(np.sum(model_prob)) + if sum_prob <= 0.0: + raise ValueError("Sum of model prob must be larger than 0!") + return model_prob / sum_prob + + +def resolve_model_prob_from_epochs( + model_keys: list[str], + num_epoch_dict_config: dict[str, float], + per_task_total: np.ndarray, +) -> tuple[np.ndarray, int, dict[str, float]]: + """Resolve model probability and training steps from epoch configuration. + + Parameters + ---------- + model_keys : list[str] + List of model keys. + num_epoch_dict_config : dict[str, float] + Target epochs for each task. + per_task_total : np.ndarray + Total batches per task. + + Returns + ------- + tuple[np.ndarray, int, dict[str, float]] + Model probabilities, total training steps, and per-task steps. + + Raises + ------ + ValueError + If input validation fails. + """ + if not num_epoch_dict_config: + raise ValueError("training.num_epoch_dict must be set for multi-task epochs.") + missing = [k for k in model_keys if k not in num_epoch_dict_config] + if missing: + raise ValueError( + f"training.num_epoch_dict must specify all tasks; missing: {missing}" + ) + epoch_targets = np.zeros(len(model_keys), dtype=np.float64) + for ii, model_key in enumerate(model_keys): + epoch_value = num_epoch_dict_config[model_key] + if epoch_value is None: + raise ValueError( + f"training.num_epoch_dict['{model_key}'] must be positive." + ) + epoch_value = float(epoch_value) + if not np.isfinite(epoch_value) or epoch_value <= 0.0: + raise ValueError( + f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}." + ) + epoch_targets[ii] = epoch_value + per_task_total = np.asarray(per_task_total, dtype=np.float64) + if per_task_total.ndim != 1: + raise ValueError("Per-task total batches must be 1D.") + if per_task_total.shape[0] != epoch_targets.shape[0]: + raise ValueError("Per-task totals and epoch targets must match.") + if not np.all(np.isfinite(per_task_total)): + raise ValueError("Per-task total batches must be finite.") + if np.any(per_task_total <= 0.0): + raise ValueError("Per-task total batches must be positive.") + per_task_steps = per_task_total * epoch_targets + total_target_steps = float(np.sum(per_task_steps)) + if total_target_steps <= 0.0: + raise ValueError("Sum of target steps must be positive.") + model_prob = per_task_steps / total_target_steps + num_steps = int(np.ceil(total_target_steps)) + per_task_steps_map = { + model_key: float(per_task_steps[ii]) for ii, model_key in enumerate(model_keys) + } + return model_prob, num_steps, per_task_steps_map diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index dd0fbdc94b..1c76787e45 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -30,6 +30,11 @@ from deepmd.common import ( symlink_prefix_files, ) +from deepmd.dpmodel.utils import ( + compute_total_numb_batch, + resolve_model_prob, + resolve_model_prob_from_epochs, +) from deepmd.dpmodel.utils.learning_rate import ( BaseLR, ) @@ -130,9 +135,12 @@ def __init__( else 1 ) self.num_model = len(self.model_keys) + self.model_prob = None # Iteration config - self.num_steps = training_params["numb_steps"] + self.num_steps = training_params.get("numb_steps") + self.num_epoch = training_params.get("num_epoch") + self.num_epoch_dict = training_params.get("num_epoch_dict") self.acc_freq: int = training_params.get( "acc_freq", 1 ) # gradient accumulation steps @@ -386,6 +394,75 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ), ) + per_task_total = [] + if not self.multi_task: + sampler_weights = to_numpy_array( + self.training_dataloader.batch_sampler.sampler.weights + ) + total_numb_batch = compute_total_numb_batch( + training_data.index, + sampler_weights, + ) + if self.num_steps is None: + if self.num_epoch is None: + raise ValueError( + "Either training.numb_steps or training.num_epoch must be set." + ) + if self.num_epoch <= 0: + raise ValueError("training.num_epoch must be positive.") + if total_numb_batch <= 0: + raise ValueError( + "Total number of training batches must be positive." + ) + self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch)) + log.info( + "Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.", + self.num_steps, + self.num_epoch, + total_numb_batch, + ) + else: + for model_key in self.model_keys: + sampler_weights = to_numpy_array( + self.training_dataloader[model_key].batch_sampler.sampler.weights + ) + per_task_total.append( + compute_total_numb_batch( + training_data[model_key].index, + sampler_weights, + ) + ) + if self.num_epoch_dict: + ( + self.model_prob, + self.num_steps, + per_task_steps, + ) = resolve_model_prob_from_epochs( + self.model_keys, + self.num_epoch_dict, + np.asarray(per_task_total, dtype=np.float64), + ) + log.info( + "Computed model_prob=%s and num_steps=%d from num_epoch_dict=%s " + "with per-task target steps: %s.", + self.model_prob, + self.num_steps, + self.num_epoch_dict, + {k: int(np.ceil(v)) for k, v in per_task_steps.items()}, + ) + else: + if self.num_steps is None: + raise ValueError( + "Either training.numb_steps (multi-task only) or " + "training.num_epoch_dict must be set." + ) + self.model_prob = resolve_model_prob( + self.model_keys, + training_params.get("model_prob"), + training_data, + rank=self.rank, + ) + # Learning rate self.warmup_steps = training_params.get("warmup_steps", 0) self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) @@ -571,6 +648,15 @@ def single_model_finetune( frz_model = paddle.jit.load(init_frz_model) self.model.set_state_dict(frz_model.state_dict()) + # Get model prob for multi-task + if self.multi_task and self.model_prob is None: + self.model_prob = resolve_model_prob( + self.model_keys, + training_params.get("model_prob"), + training_data, + rank=self.rank, + ) + # Multi-task share params if shared_links is not None: self.wrapper.share_params( @@ -678,21 +764,6 @@ def warm_up_linear(step, warmup_steps): ) self.optimizer = fleet.distributed_optimizer(self.optimizer) - # Get model prob for multi-task - if self.multi_task: - self.model_prob = np.array([0.0 for key in self.model_keys]) - if training_params.get("model_prob", None) is not None: - model_prob = training_params["model_prob"] - for ii, model_key in enumerate(self.model_keys): - if model_key in model_prob: - self.model_prob[ii] += float(model_prob[model_key]) - else: - for ii, model_key in enumerate(self.model_keys): - self.model_prob[ii] += float(len(self.training_data[model_key])) - sum_prob = np.sum(self.model_prob) - assert sum_prob > 0.0, "Sum of model prob must be larger than 0!" - self.model_prob = self.model_prob / sum_prob - # Tensorboard self.enable_tensorboard = training_params.get("tensorboard", False) self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log") diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 0dfbe94b6b..32b0200e65 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -23,6 +23,11 @@ from deepmd.common import ( symlink_prefix_files, ) +from deepmd.dpmodel.utils import ( + compute_total_numb_batch, + resolve_model_prob, + resolve_model_prob_from_epochs, +) from deepmd.loggers.training import ( format_training_message, format_training_message_per_task, @@ -139,9 +144,12 @@ def __init__( else 1 ) self.num_model = len(self.model_keys) + self.model_prob = None # Iteration config - self.num_steps = training_params["numb_steps"] + self.num_steps = training_params.get("numb_steps") + self.num_epoch = training_params.get("num_epoch") + self.num_epoch_dict = training_params.get("num_epoch_dict") self.disp_file = training_params.get("disp_file", "lcurve.out") self.disp_freq = training_params.get("disp_freq", 1000) self.disp_avg = training_params.get("disp_avg", False) @@ -430,6 +438,74 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ), ) + # Resolve training steps + per_task_total = [] + if not self.multi_task: + sampler_weights = to_numpy_array(self.training_dataloader.sampler.weights) + total_numb_batch = compute_total_numb_batch( + training_data.index, + sampler_weights, + ) + if self.num_steps is None: + if self.num_epoch is None: + raise ValueError( + "Either training.numb_steps or training.num_epoch must be set." + ) + if self.num_epoch <= 0: + raise ValueError("training.num_epoch must be positive.") + if total_numb_batch <= 0: + raise ValueError( + "Total number of training batches must be positive." + ) + self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch)) + log.info( + "Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.", + self.num_steps, + self.num_epoch, + total_numb_batch, + ) + else: + for model_key in self.model_keys: + sampler_weights = to_numpy_array( + self.training_dataloader[model_key].sampler.weights + ) + per_task_total.append( + compute_total_numb_batch( + training_data[model_key].index, + sampler_weights, + ) + ) + if self.num_epoch_dict: + ( + self.model_prob, + self.num_steps, + per_task_steps, + ) = resolve_model_prob_from_epochs( + self.model_keys, + self.num_epoch_dict, + np.asarray(per_task_total, dtype=np.float64), + ) + log.info( + "Computed model_prob=%s and num_steps=%d from num_epoch_dict=%s " + "with per-task target steps: %s.", + self.model_prob, + self.num_steps, + self.num_epoch_dict, + {k: int(np.ceil(v)) for k, v in per_task_steps.items()}, + ) + else: + if self.num_steps is None: + raise ValueError( + "Either training.numb_steps (multi-task only) or " + "training.num_epoch_dict must be set." + ) + self.model_prob = resolve_model_prob( + self.model_keys, + training_params.get("model_prob"), + training_data, + rank=self.rank, + ) + # Learning rate warmup_steps = training_params.get("warmup_steps", None) warmup_ratio = training_params.get("warmup_ratio", None) @@ -653,19 +729,13 @@ def single_model_finetune( ) # Get model prob for multi-task - if self.multi_task: - self.model_prob = np.array([0.0 for key in self.model_keys]) - if training_params.get("model_prob", None) is not None: - model_prob = training_params["model_prob"] - for ii, model_key in enumerate(self.model_keys): - if model_key in model_prob: - self.model_prob[ii] += float(model_prob[model_key]) - else: - for ii, model_key in enumerate(self.model_keys): - self.model_prob[ii] += float(len(self.training_data[model_key])) - sum_prob = np.sum(self.model_prob) - assert sum_prob > 0.0, "Sum of model prob must be larger than 0!" - self.model_prob = self.model_prob / sum_prob + if self.multi_task and self.model_prob is None: + self.model_prob = resolve_model_prob( + self.model_keys, + training_params.get("model_prob"), + training_data, + rank=self.rank, + ) # Multi-task share params if shared_links is not None: diff --git a/deepmd/tf/entrypoints/change_bias.py b/deepmd/tf/entrypoints/change_bias.py index adf65c0e2b..ef85e1ab9d 100644 --- a/deepmd/tf/entrypoints/change_bias.py +++ b/deepmd/tf/entrypoints/change_bias.py @@ -15,6 +15,9 @@ expand_sys_str, j_loader, ) +from deepmd.dpmodel.utils import ( + compute_total_numb_batch, +) from deepmd.tf.entrypoints.freeze import ( freeze, ) @@ -187,7 +190,24 @@ def _change_bias_checkpoint_file( data = _load_data_systems(datafile, system, trainer) # Get stop_batch and origin_type_map like in train.py - stop_batch = jdata.get("training", {}).get("numb_steps", 0) + training_params = jdata.get("training", {}) + stop_batch = training_params.get("numb_steps") + num_epoch = training_params.get("num_epoch") + if stop_batch is None and num_epoch is not None: + if num_epoch <= 0: + raise ValueError("training.num_epoch must be positive.") + total_numb_batch = compute_total_numb_batch(data.nbatches, data.sys_probs) + if total_numb_batch <= 0: + raise ValueError("Total number of training batches must be positive.") + stop_batch = int(np.ceil(num_epoch * total_numb_batch)) + log.info( + "Computed numb_steps=%d from num_epoch=%s and total_numb_batch=%d.", + stop_batch, + num_epoch, + total_numb_batch, + ) + if stop_batch is None: + stop_batch = 0 origin_type_map = jdata["model"].get("origin_type_map", None) if origin_type_map is not None and not origin_type_map: # get the type_map from data if not provided diff --git a/deepmd/tf/entrypoints/train.py b/deepmd/tf/entrypoints/train.py index 3ab55e190c..e327947d9b 100755 --- a/deepmd/tf/entrypoints/train.py +++ b/deepmd/tf/entrypoints/train.py @@ -12,9 +12,14 @@ Any, ) +import numpy as np + from deepmd.common import ( j_loader, ) +from deepmd.dpmodel.utils import ( + compute_total_numb_batch, +) from deepmd.tf.env import ( GLOBAL_ENER_FLOAT_PRECISION, reset_default_tf_session_config, @@ -252,7 +257,32 @@ def _do_work( modifier.build_fv_graph() # get training info - stop_batch = jdata["training"]["numb_steps"] + training_params = jdata["training"] + stop_batch = training_params.get("numb_steps") + num_epoch = training_params.get("num_epoch") + if stop_batch is None: + if num_epoch is None: + raise ValueError( + "Either training.numb_steps or training.num_epoch must be set." + ) + if num_epoch <= 0: + raise ValueError("training.num_epoch must be positive.") + if train_data is None: + raise ValueError( + "training.num_epoch requires training data to compute total_numb_batch." + ) + total_numb_batch = compute_total_numb_batch( + train_data.nbatches, train_data.sys_probs + ) + if total_numb_batch <= 0: + raise ValueError("Total number of training batches must be positive.") + stop_batch = int(np.ceil(num_epoch * total_numb_batch)) + log.info( + "Computed numb_steps=%d from num_epoch=%s and total_numb_batch=%d.", + stop_batch, + num_epoch, + total_numb_batch, + ) origin_type_map = jdata["model"].get("origin_type_map", None) if ( origin_type_map is not None and not origin_type_map diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 935762cdc7..50faa611a5 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3213,7 +3213,35 @@ def mixed_precision_args() -> list[Argument]: # ! added by Denghui. def training_args( multi_task: bool = False, ) -> list[Argument]: # ! modified by Ziyao: data configuration isolated. - doc_numb_steps = "Number of training batch. Each training uses one batch of data." + doc_numb_steps = ( + "Number of training steps (num_step). Each training uses one batch of data. " + "Mutually exclusive with num_epoch in single-task mode. In multi-task " + "mode, this is mutually exclusive with num_epoch_dict. " + "Accepted names: num_step, num_steps, numb_step, numb_steps." + ) + doc_num_epoch = ( + "Number of training epochs (num_epoch; can be fractional) for single-task " + "mode only. Because each step samples the dataset stochastically, this " + "corresponds to an expected epoch count rather than a deterministic full " + "pass. When num_step is not set, the total steps are computed as " + "ceil(num_epoch * total_numb_batch). total_numb_batch is computed as " + "ceil(max_i(n_bch_i / p_i)), where n_bch_i is the number of batches for " + "system i and p_i is the sampling probability after sys_probs/auto_prob " + "normalization. Mutually exclusive with num_step. For multi-task mode, " + "use num_epoch_dict instead. Accepted names: num_epoch, num_epochs, " + "numb_epoch, numb_epochs." + ) + doc_num_epoch_dict = ( + "Number of training epochs for each model branch in multi-task mode " + "(can be fractional). This is a dictionary mapping model keys to the " + "number of epochs to train that specific model. When set, model_prob " + "is derived from the epoch targets and per-task total_numb_batch values: " + "model_prob[i] = num_epoch_dict[i] * per_task_total[i] / sum_j(num_epoch_dict[j] * per_task_total[j]). " + "Total training steps are computed as " + "ceil(sum_i(num_epoch_dict[i] * per_task_total[i])). " + "This parameter is mutually exclusive with training.model_prob and " + "training.num_step. All model keys must be specified in the dictionary." + ) doc_seed = "The random seed for getting frames from the training data set." doc_disp_file = "The file for printing learning curve." doc_disp_freq = "The frequency of printing learning curve." @@ -3270,7 +3298,12 @@ def training_args( ) doc_opt_type = "The type of optimizer to use." doc_kf_blocksize = "The blocksize for the Kalman filter." - doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode." + doc_model_prob = ( + "The visiting probability of each model for each training step in the " + "multi-task mode. Only used when num_epoch_dict is not set. If not set " + "or an empty dict, defaults to weights proportional to the number of " + "systems per task." + ) doc_data_dict = "The multiple definition of the data, used in the multi-task mode." doc_acc_freq = "Gradient accumulation steps (number of steps to accumulate gradients before performing an update)." @@ -3290,6 +3323,13 @@ def training_args( if not multi_task else [ Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob), + Argument( + "num_epoch_dict", + dict, + optional=True, + default={}, + doc=doc_num_epoch_dict, + ), Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict), ] ) @@ -3297,7 +3337,23 @@ def training_args( args += [ mixed_precision_data, Argument( - "numb_steps", int, optional=False, doc=doc_numb_steps, alias=["stop_batch"] + "numb_steps", + int, + optional=True, + doc=doc_numb_steps, + alias=[ + "stop_batch", + "num_step", + "num_steps", + "numb_step", + ], + ), + Argument( + "num_epoch", + [int, float], + optional=True, + doc=doc_num_epoch, + alias=["num_epochs", "numb_epoch", "numb_epochs"], ), Argument("seed", [int, None], optional=True, doc=doc_seed), Argument( @@ -3481,8 +3537,52 @@ def training_args( ) ] + def training_extra_check(data: dict | None) -> bool: + if data is None: + return True + num_steps = data.get("numb_steps") + num_epoch = data.get("num_epoch") + num_epoch_dict = data.get("num_epoch_dict", {}) + model_prob = data.get("model_prob", {}) + if multi_task: + if num_epoch is not None: + raise ValueError( + "training.num_epoch is only supported in single-task mode." + ) + if num_epoch_dict: + if num_steps is not None: + raise ValueError( + "training.num_epoch_dict is mutually exclusive with training.num_step." + ) + if model_prob: + raise ValueError( + "training.num_epoch_dict is mutually exclusive with training.model_prob." + ) + else: + if num_steps is None: + raise ValueError( + "Multi-task mode requires either training.numb_steps or training.num_epoch_dict." + ) + else: + if num_steps is not None and num_epoch is not None: + raise ValueError( + "training.num_step and training.num_epoch are mutually exclusive." + ) + if num_steps is None and num_epoch is None: + raise ValueError( + "Single-task mode requires either training.numb_steps or training.num_epoch." + ) + return True + doc_training = "The training options." - return Argument("training", dict, args, variants, doc=doc_training) + return Argument( + "training", + dict, + args, + variants, + doc=doc_training, + extra_check=training_extra_check, + ) def multi_model_args() -> list[Argument]: diff --git a/doc/train/multi-task-training.md b/doc/train/multi-task-training.md index 115c463cc2..2b420bb8a5 100644 --- a/doc/train/multi-task-training.md +++ b/doc/train/multi-task-training.md @@ -79,7 +79,15 @@ Specifically, there are several parts that need to be modified: - (Optional) {ref}`training/model_prob `: The sampling weight settings corresponding to each `model_key`, i.e., the probability weight in the training step. You can specify any positive real number weight for each task. The higher the weight, the higher the probability of being sampled in each training. - This setting is optional, and if not set, tasks will be sampled with equal weights. + This setting is optional, and if not set, tasks will be sampled with equal weights. It is only used when `num_epoch_dict` is not set. + +- (Optional) {ref}`training/num_epoch_dict `: The number of training epochs for each model branch, specified as a dictionary mapping `model_key` to epoch values (can be fractional). + This allows different tasks to train for different numbers of epochs, which is particularly useful for multi-task fine-tuning scenarios + where a data-rich pretrained model is jointly trained with a data-scarce downstream task. + When set, `model_prob` is derived from the epoch targets and per-task totals: + `model_prob[i] = num_epoch_dict[i] * per_task_total[i] / sum_j(num_epoch_dict[j] * per_task_total[j])`. + The total training steps are computed as `ceil(sum_i(num_epoch_dict[i] * per_task_total[i]))`. + This parameter is mutually exclusive with `training/model_prob` and `training/num_steps`. An example input for multi-task training two models in water system is shown as following: diff --git a/source/tests/pt/test_sampler.py b/source/tests/pt/test_sampler.py index 3d7143b350..af0b5c1a98 100644 --- a/source/tests/pt/test_sampler.py +++ b/source/tests/pt/test_sampler.py @@ -7,15 +7,18 @@ ) import numpy as np +import pytest import torch from torch.utils.data import ( DataLoader, ) -from deepmd.pt.utils.dataloader import ( - DpLoaderSet, - get_sampler_from_params, - get_weighted_sampler, +import deepmd.pt.utils.dataloader as pt_dataloader +from deepmd.dpmodel.utils import ( + compute_total_numb_batch, +) +from deepmd.pt.utils import ( + dp_random, ) from deepmd.tf.common import ( expand_sys_str, @@ -28,8 +31,25 @@ CUR_DIR = os.path.dirname(__file__) +class _SerialPool: + def __init__(self, *args, **kwargs) -> None: + pass + + def __enter__(self) -> "_SerialPool": + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def map(self, func, iterable): + return [func(item) for item in iterable] + + class TestSampler(unittest.TestCase): def setUp(self) -> None: + self._monkeypatch = pytest.MonkeyPatch() + # Avoid SemLock/CUDA initialization failures in restricted CI by forcing a serial pool. + self._monkeypatch.setattr(pt_dataloader, "Pool", _SerialPool) with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: content = fin.read() config = json.loads(content) @@ -40,14 +60,15 @@ def setUp(self) -> None: self.rcut = model_config["descriptor"]["rcut"] self.rcut_smth = model_config["descriptor"]["rcut_smth"] self.sel = model_config["descriptor"]["sel"] + self.type_map = model_config["type_map"] self.batch_size = config["training"]["training_data"]["batch_size"] self.systems = config["training"]["validation_data"]["systems"] if isinstance(self.systems, str): self.systems = expand_sys_str(self.systems) - self.my_dataset = DpLoaderSet( + self.my_dataset = pt_dataloader.DpLoaderSet( self.systems, self.batch_size, - model_config["type_map"], + self.type_map, seed=10, shuffle=False, ) @@ -55,10 +76,86 @@ def setUp(self) -> None: tf_random.seed(10) self.dp_dataset = DeepmdDataSystem(self.systems, self.batch_size, 1, self.rcut) + def tearDown(self) -> None: + self._monkeypatch.undo() + + def _make_dataloader( + self, dataset: pt_dataloader.DpLoaderSet, sampler + ) -> DataLoader: + return DataLoader( + dataset, + sampler=sampler, + batch_size=None, + num_workers=0, + drop_last=False, + collate_fn=lambda batch: batch, + ) + + def _normalize_probs(self, weights: np.ndarray) -> np.ndarray: + weights = np.asarray(weights, dtype=np.float64) + return weights / np.sum(weights) + + def _sample_sid_counts( + self, dataloader: DataLoader, num_steps: int, nsystems: int + ) -> np.ndarray: + # === Step 1. Initialize Counters === + counts = np.zeros(nsystems, dtype=np.int64) + # === Step 2. Sample Steps === + with torch.device("cpu"): + iterator = iter(dataloader) + for _ in range(num_steps): + try: + batch_data = next(iterator) + except StopIteration: + iterator = iter(dataloader) + batch_data = next(iterator) + sid = batch_data["sid"] + if hasattr(sid, "item"): + sid = sid.item() + counts[int(sid)] += 1 + return counts + + def _sample_multitask_counts( + self, + dataloaders: dict[str, DataLoader], + model_prob: np.ndarray, + num_steps: int, + ) -> tuple[np.ndarray, dict[str, np.ndarray]]: + # === Step 1. Initialize Counters === + model_keys = list(dataloaders.keys()) + model_counts = np.zeros(len(model_keys), dtype=np.int64) + sid_counts = { + model_key: np.zeros(len(dataloaders[model_key].dataset), dtype=np.int64) + for model_key in model_keys + } + # === Step 2. Build Iterators and Sample Steps === + with torch.device("cpu"): + iters = { + model_key: iter(dataloaders[model_key]) for model_key in model_keys + } + for _ in range(num_steps): + model_index = dp_random.choice( + np.arange(len(model_keys), dtype=np.int_), p=model_prob + ) + model_key = model_keys[int(model_index)] + model_counts[int(model_index)] += 1 + try: + batch_data = next(iters[model_key]) + except StopIteration: + iters[model_key] = iter(dataloaders[model_key]) + batch_data = next(iters[model_key]) + sid = batch_data["sid"] + if hasattr(sid, "item"): + sid = sid.item() + sid_counts[model_key][int(sid)] += 1 + return model_counts, sid_counts + def test_sampler_debug_info(self) -> None: dataloader = DataLoader( self.my_dataset, - sampler=get_weighted_sampler(self.my_dataset, prob_style="prob_sys_size"), + sampler=pt_dataloader.get_weighted_sampler( + self.my_dataset, prob_style="prob_sys_size" + ), batch_size=None, num_workers=0, # setting to 0 diverges the behavior of its iterator; should be >=1 drop_last=False, @@ -73,7 +170,9 @@ def test_sampler_debug_info(self) -> None: def test_auto_prob_uniform(self) -> None: auto_prob_style = "prob_uniform" - sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style) + sampler = pt_dataloader.get_weighted_sampler( + self.my_dataset, prob_style=auto_prob_style + ) my_probs = np.array(sampler.weights) self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style) dp_probs = np.array(self.dp_dataset.sys_probs) @@ -81,7 +180,9 @@ def test_auto_prob_uniform(self) -> None: def test_auto_prob_sys_size(self) -> None: auto_prob_style = "prob_sys_size" - sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style) + sampler = pt_dataloader.get_weighted_sampler( + self.my_dataset, prob_style=auto_prob_style + ) my_probs = np.array(sampler.weights) self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style) dp_probs = np.array(self.dp_dataset.sys_probs) @@ -89,7 +190,9 @@ def test_auto_prob_sys_size(self) -> None: def test_auto_prob_sys_size_ext(self) -> None: auto_prob_style = "prob_sys_size;0:1:0.2;1:3:0.8" - sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style) + sampler = pt_dataloader.get_weighted_sampler( + self.my_dataset, prob_style=auto_prob_style + ) my_probs = np.array(sampler.weights) self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style) dp_probs = np.array(self.dp_dataset.sys_probs) @@ -97,7 +200,7 @@ def test_auto_prob_sys_size_ext(self) -> None: def test_sys_probs(self) -> None: sys_probs = [0.1, 0.4, 0.5] - sampler = get_weighted_sampler( + sampler = pt_dataloader.get_weighted_sampler( self.my_dataset, prob_style=sys_probs, sys_prob=True ) my_probs = np.array(sampler.weights) @@ -111,7 +214,7 @@ def test_sys_probs_end2end(self): "sys_probs": sys_probs, "auto_prob": "prob_sys_size", } # use sys_probs first - sampler = get_sampler_from_params(self.my_dataset, _params) + sampler = pt_dataloader.get_sampler_from_params(self.my_dataset, _params) my_probs = np.array(sampler.weights) self.dp_dataset.set_sys_probs(sys_probs=sys_probs) dp_probs = np.array(self.dp_dataset.sys_probs) @@ -120,12 +223,293 @@ def test_sys_probs_end2end(self): def test_auto_prob_sys_size_ext_end2end(self): auto_prob_style = "prob_sys_size;0:1:0.2;1:3:0.8" _params = {"sys_probs": None, "auto_prob": auto_prob_style} # use auto_prob - sampler = get_sampler_from_params(self.my_dataset, _params) + sampler = pt_dataloader.get_sampler_from_params(self.my_dataset, _params) my_probs = np.array(sampler.weights) self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style) dp_probs = np.array(self.dp_dataset.sys_probs) self.assertTrue(np.allclose(my_probs, dp_probs)) + def test_sampling_stability_single_task(self) -> None: + # === Step 1. Build Dataset and Sampler === + systems = [ + str(Path(__file__).parent / "water/data/data_0"), + str(Path(__file__).parent / "water/data/data_1"), + str(Path(__file__).parent / "water/data/single"), + ] + dataset_epoch = pt_dataloader.DpLoaderSet( + systems, + self.batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + sys_probs = [0.2, 0.3, 0.5] + params = {"sys_probs": sys_probs, "auto_prob": "prob_sys_size"} + sampler_epoch = pt_dataloader.get_sampler_from_params(dataset_epoch, params) + nbatches = np.asarray(dataset_epoch.index, dtype=np.float64) + total_numb_batch = compute_total_numb_batch( + nbatches, np.asarray(sampler_epoch.weights) + ) + num_epoch = 1.5 + num_steps = int(np.ceil(num_epoch * total_numb_batch)) + probs = self._normalize_probs(np.asarray(sampler_epoch.weights)) + + # === Step 2. Sample Using Derived Steps === + torch.manual_seed(123) + dataloader_epoch = self._make_dataloader(dataset_epoch, sampler_epoch) + counts_epoch = self._sample_sid_counts( + dataloader_epoch, num_steps, len(dataset_epoch) + ) + empirical_epoch = counts_epoch / float(num_steps) + self.assertTrue(np.allclose(empirical_epoch, probs, atol=0.1)) + + # === Step 3. Sample Using Explicit Steps === + dataset_steps = pt_dataloader.DpLoaderSet( + systems, + self.batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + sampler_steps = pt_dataloader.get_sampler_from_params(dataset_steps, params) + torch.manual_seed(123) + dataloader_steps = self._make_dataloader(dataset_steps, sampler_steps) + counts_steps = self._sample_sid_counts( + dataloader_steps, num_steps, len(dataset_steps) + ) + self.assertTrue(np.array_equal(counts_epoch, counts_steps)) + + def test_sampling_stability_multi_task(self) -> None: + # === Step 1. Build Datasets and Samplers === + model_keys = ["model_1", "model_2"] + systems_1 = [ + str(Path(__file__).parent / "water/data/data_0"), + str(Path(__file__).parent / "water/data/data_1"), + ] + systems_2 = [ + str(Path(__file__).parent / "water/data/data_1"), + str(Path(__file__).parent / "water/data/single"), + ] + dataset_1 = pt_dataloader.DpLoaderSet( + systems_1, + self.batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + dataset_2 = pt_dataloader.DpLoaderSet( + systems_2, + self.batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + sampler_1 = pt_dataloader.get_sampler_from_params( + dataset_1, {"sys_probs": [0.7, 0.3], "auto_prob": "prob_sys_size"} + ) + sampler_2 = pt_dataloader.get_sampler_from_params( + dataset_2, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"} + ) + probs_1 = self._normalize_probs(np.asarray(sampler_1.weights)) + probs_2 = self._normalize_probs(np.asarray(sampler_2.weights)) + per_task_total = np.array( + [ + compute_total_numb_batch( + np.asarray(dataset_1.index, dtype=np.float64), + np.asarray(sampler_1.weights), + ), + compute_total_numb_batch( + np.asarray(dataset_2.index, dtype=np.float64), + np.asarray(sampler_2.weights), + ), + ], + dtype=np.float64, + ) + num_epoch_dict = {model_keys[0]: 1.5, model_keys[1]: 0.8} + target_steps = np.array( + [ + num_epoch_dict[model_keys[0]] * per_task_total[0], + num_epoch_dict[model_keys[1]] * per_task_total[1], + ], + dtype=np.float64, + ) + total_target_steps = float(np.sum(target_steps)) + model_prob = target_steps / total_target_steps + num_steps = int(np.ceil(total_target_steps)) + + # === Step 2. Sample Using Derived Steps === + dataloaders_epoch = { + model_keys[0]: self._make_dataloader(dataset_1, sampler_1), + model_keys[1]: self._make_dataloader(dataset_2, sampler_2), + } + dp_random.seed(321) + torch.manual_seed(321) + model_counts_epoch, sid_counts_epoch = self._sample_multitask_counts( + dataloaders_epoch, model_prob, num_steps + ) + model_freq_epoch = model_counts_epoch / float(num_steps) + self.assertTrue(np.allclose(model_freq_epoch, model_prob, atol=0.1)) + if model_counts_epoch[0] == 0 or model_counts_epoch[1] == 0: + raise AssertionError("Model sampling produced zero counts for a task.") + self.assertTrue( + np.allclose( + sid_counts_epoch[model_keys[0]] / model_counts_epoch[0], + probs_1, + atol=0.1, + ) + ) + self.assertTrue( + np.allclose( + sid_counts_epoch[model_keys[1]] / model_counts_epoch[1], + probs_2, + atol=0.1, + ) + ) + + # === Step 3. Sample Using Explicit Steps === + dataset_1b = pt_dataloader.DpLoaderSet( + systems_1, + self.batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + dataset_2b = pt_dataloader.DpLoaderSet( + systems_2, + self.batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + sampler_1b = pt_dataloader.get_sampler_from_params( + dataset_1b, {"sys_probs": [0.7, 0.3], "auto_prob": "prob_sys_size"} + ) + sampler_2b = pt_dataloader.get_sampler_from_params( + dataset_2b, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"} + ) + dataloaders_steps = { + model_keys[0]: self._make_dataloader(dataset_1b, sampler_1b), + model_keys[1]: self._make_dataloader(dataset_2b, sampler_2b), + } + dp_random.seed(321) + torch.manual_seed(321) + model_counts_steps, sid_counts_steps = self._sample_multitask_counts( + dataloaders_steps, model_prob, num_steps + ) + self.assertTrue(np.array_equal(model_counts_epoch, model_counts_steps)) + self.assertTrue( + np.array_equal( + sid_counts_epoch[model_keys[0]], sid_counts_steps[model_keys[0]] + ) + ) + self.assertTrue( + np.array_equal( + sid_counts_epoch[model_keys[1]], sid_counts_steps[model_keys[1]] + ) + ) + + def test_num_epoch_dict(self) -> None: + """Test num_epoch_dict calculation logic for multi-task training.""" + # === Step 1. Build Datasets === + model_keys = ["model_1", "model_2"] + systems_1 = [ + str(Path(__file__).parent / "water/data/data_0"), + str(Path(__file__).parent / "water/data/data_1"), + ] + systems_2 = [ + str(Path(__file__).parent / "water/data/data_1"), + str(Path(__file__).parent / "water/data/single"), + ] + dataset_1 = pt_dataloader.DpLoaderSet( + systems_1, + self.batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + dataset_2 = pt_dataloader.DpLoaderSet( + systems_2, + self.batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + sampler_1 = pt_dataloader.get_sampler_from_params( + dataset_1, {"sys_probs": [0.7, 0.3], "auto_prob": "prob_sys_size"} + ) + sampler_2 = pt_dataloader.get_sampler_from_params( + dataset_2, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"} + ) + + # === Step 2. Compute per-task total_numb_batch === + per_task_total = np.array( + [ + compute_total_numb_batch( + np.asarray(dataset_1.index, dtype=np.float64), + np.asarray(sampler_1.weights), + ), + compute_total_numb_batch( + np.asarray(dataset_2.index, dtype=np.float64), + np.asarray(sampler_2.weights), + ), + ], + dtype=np.float64, + ) + + # === Step 3. Test num_epoch_dict calculation === + num_epoch_dict = {model_keys[0]: 2.0, model_keys[1]: 5.0} + + # Compute expected steps and model_prob from epoch targets + per_task_steps = np.array( + [ + num_epoch_dict[model_keys[0]] * per_task_total[0], + num_epoch_dict[model_keys[1]] * per_task_total[1], + ], + dtype=np.float64, + ) + total_target_steps = float(np.sum(per_task_steps)) + model_prob = per_task_steps / total_target_steps + expected_num_steps = int(np.ceil(total_target_steps)) + + # Verify the calculation matches the expected formula + self.assertIsInstance(expected_num_steps, int) + self.assertGreater(expected_num_steps, 0) + + # Verify that running expected_num_steps would give each task at least + # its target epochs (may be more for tasks needing fewer steps) + expected_model_0_counts = expected_num_steps * model_prob[0] + expected_model_1_counts = expected_num_steps * model_prob[1] + + # Each task should complete at least its target epochs + expected_epochs_0 = expected_model_0_counts / per_task_total[0] + expected_epochs_1 = expected_model_1_counts / per_task_total[1] + + self.assertGreaterEqual( + expected_epochs_0, + num_epoch_dict[model_keys[0]], + msg="Model 0 should complete at least 2 epochs", + ) + self.assertGreaterEqual( + expected_epochs_1, + num_epoch_dict[model_keys[1]], + msg="Model 1 should complete at least 5 epochs", + ) + + # All tasks should be scaled by the same rounding factor. + scale_0 = expected_epochs_0 / num_epoch_dict[model_keys[0]] + scale_1 = expected_epochs_1 / num_epoch_dict[model_keys[1]] + self.assertGreaterEqual( + scale_0, + 1.0, + msg="Rounding should not reduce expected epochs.", + ) + self.assertAlmostEqual( + scale_0, + scale_1, + delta=0.1, + msg="Rounding should scale all tasks consistently.", + ) + if __name__ == "__main__": unittest.main()