|
1 | 1 | package org.beehive.gpullama3.inference.sampler;
|
2 | 2 |
|
| 3 | +import org.beehive.gpullama3.Options; |
3 | 4 | import org.beehive.gpullama3.core.model.tensor.FloatTensor;
|
| 5 | +import org.beehive.gpullama3.model.Model; |
| 6 | +import org.beehive.gpullama3.tornadovm.FloatArrayUtils; |
4 | 7 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
|
5 | 8 |
|
| 9 | +import java.util.random.RandomGenerator; |
| 10 | +import java.util.random.RandomGeneratorFactory; |
| 11 | + |
6 | 12 | /**
|
7 | 13 | * Generic interface for sampling tokens from probability distributions.
|
8 | 14 | * Supports both FloatTensor and FloatArray tensor implementations.
|
9 | 15 | */
|
10 | 16 | @FunctionalInterface
|
11 | 17 | public interface Sampler {
|
| 18 | + |
| 19 | + /** |
| 20 | + * Creates and configures a sampler for token generation based on specified parameters. |
| 21 | + * |
| 22 | + * <p>This method selects an appropriate sampling strategy for next-token prediction |
| 23 | + * in language model inference. It supports several sampling approaches:</p> |
| 24 | + * |
| 25 | + * <ul> |
| 26 | + * <li>Greedy sampling (temperature = 0): Always selects the most probable token</li> |
| 27 | + * <li>Temperature sampling: Adjusts probability distribution sharpness</li> |
| 28 | + * <li>Top-p (nucleus) sampling: Considers only tokens comprising the top p probability mass</li> |
| 29 | + * </ul> |
| 30 | + * |
| 31 | + * <p>The method handles both {@link FloatTensor} and {@link FloatArray} logits types |
| 32 | + * to support both CPU and GPU execution paths.</p> |
| 33 | + * |
| 34 | + * @param vocabularySize |
| 35 | + * The size of the model's vocabulary |
| 36 | + * @param temperature |
| 37 | + * A value controlling randomness in sampling: |
| 38 | + * <ul> |
| 39 | + * <li>0.0f: No randomness (greedy sampling)</li> |
| 40 | + * <li>1.0f: Standard sampling from unmodified distribution</li> |
| 41 | + * <li><1.0f: More deterministic (sharper distribution)</li> |
| 42 | + * <li>>1.0f: More random (flatter distribution)</li> |
| 43 | + * </ul> |
| 44 | + * @param topp |
| 45 | + * The cumulative probability threshold for nucleus sampling (0.0-1.0). |
| 46 | + * <ul> |
| 47 | + * <li>Values ≤0 or ≥1: Disables top-p sampling</li> |
| 48 | + * <li>Values in (0,1): Restricts sampling to tokens comprising the top p probability mass</li> |
| 49 | + * </ul> |
| 50 | + * @param rngSeed |
| 51 | + * Seed value for the random number generator to ensure reproducibility |
| 52 | + * @return A configured {@link Sampler} that implements the selected sampling strategy and handles both tensor and array-based logits |
| 53 | + * @throws IllegalArgumentException |
| 54 | + * if logits are of an unsupported type |
| 55 | + */ |
| 56 | + static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) { |
| 57 | + Sampler sampler; |
| 58 | + if (temperature == 0.0f) { |
| 59 | + // greedy argmax sampling: take the token with the highest probability |
| 60 | + sampler = Sampler.TENSOR_ARGMAX; // Use TENSOR_ARGMAX instead of ARGMAX |
| 61 | + } else { |
| 62 | + // we sample from this distribution to get the next token |
| 63 | + RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed); |
| 64 | + Sampler innerSampler; |
| 65 | + // Determine whether to use top-p (nucleus) sampling |
| 66 | + if (topp <= 0 || topp >= 1) { |
| 67 | + // If topp is outside (0,1), use standard categorical sampling |
| 68 | + // This samples directly from the probability distribution |
| 69 | + innerSampler = new CategoricalSampler(rng); |
| 70 | + } else { |
| 71 | + // Use top-p (nucleus) sampling with the specified threshold |
| 72 | + // This restricts sampling to only the most likely tokens that |
| 73 | + // cumulatively comprise the top p probability mass |
| 74 | + innerSampler = new ToppSampler(vocabularySize, topp, rng); |
| 75 | + } |
| 76 | + |
| 77 | + // Create a sampler that: |
| 78 | + // 1. Applies temperature scaling to the logits |
| 79 | + // 2. Converts logits to probabilities using softmax |
| 80 | + // 3. Delegates the actual sampling to the appropriate inner sampler |
| 81 | + sampler = logits -> { |
| 82 | + // Handle different logits formats to support both CPU and GPU paths |
| 83 | + if (logits instanceof FloatTensor) { |
| 84 | + // For CPU path using FloatTensor |
| 85 | + FloatTensor tensorLogits = (FloatTensor) logits; |
| 86 | + // Apply temperature scaling - lower values make distribution more peaked |
| 87 | + tensorLogits.divideInPlace(0, tensorLogits.size(), temperature); |
| 88 | + // Convert logits to probabilities using softmax |
| 89 | + tensorLogits.softmaxInPlace(0, tensorLogits.size()); |
| 90 | + } else if (logits instanceof FloatArray) { |
| 91 | + // For GPU path using FloatArray |
| 92 | + FloatArray arrayLogits = (FloatArray) logits; |
| 93 | + // Apply the same operations but using FloatArray-specific methods for TornadoVM data types |
| 94 | + FloatArrayUtils.divideInPlace(arrayLogits, 0, arrayLogits.getSize(), temperature); |
| 95 | + FloatArrayUtils.softmaxInPlace(arrayLogits, 0, arrayLogits.getSize()); |
| 96 | + } else { |
| 97 | + // If logits are neither FloatTensor nor FloatArray, throw an exception |
| 98 | + throw new IllegalArgumentException("Unsupported logits type: " + (logits != null ? logits.getClass().getName() : "null")); |
| 99 | + } |
| 100 | + return innerSampler.sampleToken(logits); |
| 101 | + }; |
| 102 | + } |
| 103 | + return sampler; |
| 104 | + } |
| 105 | + |
| 106 | + static Sampler createSampler(Model model, Options options) { |
| 107 | + return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); |
| 108 | + } |
| 109 | + |
12 | 110 | /**
|
13 | 111 | * Sample a token from the provided tensor.
|
14 | 112 | *
|
|
0 commit comments