diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 2e93670e6..2124fcdaa 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -304,7 +304,11 @@ def __init__( self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count() # Used by the sampler - self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED + if seed == -1: + # set a random seed + self._seed = random.randint(0, 2 ** 32) + else: + self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED # Context Params self.context_params = llama_cpp.llama_context_default_params()