diff --git a/doc/flags.md b/doc/flags.md
index f2652bbf..41581ee4 100644
--- a/doc/flags.md
+++ b/doc/flags.md
@@ -57,3 +57,4 @@ The sampling script `sample.lua` accepts the following command-line flags:
 - `-gpu`: The ID of the GPU to use (zero-indexed). Default is 0. Set this to -1 to run in CPU-only mode.
 - `-gpu_backend`: The GPU backend to use; either `cuda` or `opencl`. Default is `cuda`.
 - `-verbose`: By default just the sampled text is printed to the console. Set this to 1 to also print some diagnostic information.
+- `-seed`: Default is 0. Set this to a nonzero value to seed the Torch RNG with that value; leave it at 0 (or omit the flag) to seed the RNG randomly. Seeding the RNG makes the output of `sample.lua` reproducible across runs with the same parameters and checkpoint. The seed used is printed when `-verbose` is 1.
diff --git a/sample.lua b/sample.lua
index 4e6ebae0..097cd5df 100644
--- a/sample.lua
+++ b/sample.lua
@@ -13,12 +13,23 @@ cmd:option('-temperature', 1)
 cmd:option('-gpu', 0)
 cmd:option('-gpu_backend', 'cuda')
 cmd:option('-verbose', 0)
+cmd:option('-seed', 0)
 local opt = cmd:parse(arg)
 
 
 local checkpoint = torch.load(opt.checkpoint)
 local model = checkpoint.model
 
+-- Seed the Torch RNG; a seed of 0 means choose a seed at random.
+if opt.seed == 0 then
+  opt.seed = torch.random()
+end
+torch.manualSeed(opt.seed)
+
+if opt.verbose == 1 then
+  print(string.format('Random number seed: %d', opt.seed))
+end
+
 local msg
 if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
   require 'cutorch'
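
For context (not part of the patch): a minimal standalone Lua sketch of the property the `-seed` flag relies on, namely that re-seeding the Torch RNG with the same value reproduces the same sequence of draws.

```lua
require 'torch'

-- Seed the Torch RNG and draw a few numbers.
torch.manualSeed(123)
local a = { torch.random(), torch.random(), torch.random() }

-- Re-seeding with the same value replays the identical sequence,
-- which is what makes repeated sampling runs deterministic for a
-- fixed checkpoint and fixed parameters.
torch.manualSeed(123)
local b = { torch.random(), torch.random(), torch.random() }

for i = 1, 3 do
  assert(a[i] == b[i])
end
print('same seed, same sequence')
```

A hypothetical invocation using the new flag (the checkpoint path here is illustrative) would be `th sample.lua -checkpoint cv/checkpoint_10000.t7 -seed 123 -verbose 1`, which prints the seed before sampling; running it again with the same arguments should emit identical text.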