diff --git a/README.md b/README.md index 8c615ae..48195d0 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,22 @@ Assuming you have [miniconda](https://docs.conda.io/en/latest/miniconda.html) in ### Training The entrypoint `train` is the main driver for training and accepts parameters using Hydra syntax. -The available parameters for configuration can be found by running `train` --help or by looking in the `src/walkjump/hydra_config` directory +The available parameters for configuration can be found by running `train` --help (```walkjump_train --help```) or by looking in the `src/walkjump/hydra_config` directory ### Sampling -The entrypoint `sample` is the main driver for training and accepts parameters using Hydra syntax. +The entrypoint `sample` is the main driver for sampling and accepts parameters using Hydra syntax. -The available parameters for configuration can be found by running `sample` --help or by looking in the `src/walkjump/hydra_config` directory +The available parameters for configuration can be found by running `sample` --help (```walkjump_sample --help```) or by looking in the `src/walkjump/hydra_config` directory + +### Example +```bash +conda activate wj +walkjump_train data.csv_data_path="data/poas.csv.gz" +``` +then +```bash +walkjump_sample 'model.checkpoint_path="checkpoints/epoch=17-step=363937-val_loss=0.0040.ckpt"' designs.output_csv=my_samples.csv +``` +(Extra quotation marks to handle "=" in file path) ## Contributing diff --git a/src/walkjump/cmdline/_sample.py b/src/walkjump/cmdline/_sample.py index 9ec403e..401ad88 100644 --- a/src/walkjump/cmdline/_sample.py +++ b/src/walkjump/cmdline/_sample.py @@ -4,7 +4,7 @@ from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, OmegaConf -from walkjump.cmdline.utils import instantiate_redesign_mask, instantiate_seeds +from walkjump.cmdline.utils import instantiate_redesign_mask, instantiate_seeds, instantiate_model_for_sample_mode from walkjump.sampling import walkjump @@ -28,7 +28,7 @@ def sample(cfg: DictConfig) -> bool: seeds =
instantiate_seeds(cfg.designs) if not cfg.dryrun: - model = hydra.utils.instantiate(cfg.model).to(device) + model = instantiate_model_for_sample_mode(cfg.model).to(device) sample_df = walkjump( seeds, model, diff --git a/src/walkjump/hydra_config/sample.yaml b/src/walkjump/hydra_config/sample.yaml index e8c4dcf..2459b13 100644 --- a/src/walkjump/hydra_config/sample.yaml +++ b/src/walkjump/hydra_config/sample.yaml @@ -3,18 +3,17 @@ defaults: - setup: default model: - _target_: walkjump.cmdline.utils.instantiate_model_for_sample_mode model_type: denoise checkpoint_path: ??? denoise_path: null langevin: - sigma: 1.0 - delta: 0.5 - lipschitz: 1.0 - friction: 1.0 - steps: 20 - chunksize: 8 + sigma: 1.0 # Noise level + delta: 0.5 # Step size + lipschitz: 1.0 # Lipschitz constant, related to mass: u = pow(lipschitz, -1) + friction: 1.0 # (Gamma) Dampening term + steps: 20 # (K) Number of steps in chain + chunksize: 8 # Used for chunking the batch to save memory. Providing chunksize = N will force the sampling to occur in N batches. designs: output_csv: samples.csv @@ -22,5 +21,7 @@ designs: seeds: denovo num_samples: 100 limit_seeds: 10 + chunksize: 8 -device: null \ No newline at end of file +device: null +dryrun: false \ No newline at end of file