Skip to content

Commit 96f851d

Browse files
Refactor Sampler: move selectSampler() from LlamaApp to Sampler class
1 parent 2b622d1 commit 96f851d

File tree

2 files changed

+99
-86
lines changed

2 files changed

+99
-86
lines changed

src/main/java/org/beehive/gpullama3/LlamaApp.java

Lines changed: 1 addition & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -15,98 +15,13 @@
1515
import java.util.random.RandomGenerator;
1616
import java.util.random.RandomGeneratorFactory;
1717

18+
import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler;
1819
public class LlamaApp {
1920
// Configuration flags for hardware acceleration and optimizations
2021
public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
2122
public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
2223
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode
2324

24-
/**
25-
* Creates and configures a sampler for token generation based on specified parameters.
26-
*
27-
* <p>This method selects an appropriate sampling strategy for next-token prediction
28-
* in language model inference. It supports several sampling approaches:</p>
29-
*
30-
* <ul>
31-
* <li>Greedy sampling (temperature = 0): Always selects the most probable token</li>
32-
* <li>Temperature sampling: Adjusts probability distribution sharpness</li>
33-
* <li>Top-p (nucleus) sampling: Considers only tokens comprising the top p probability mass</li>
34-
* </ul>
35-
*
36-
* <p>The method handles both {@link FloatTensor} and {@link FloatArray} logits types
37-
* to support both CPU and GPU execution paths.</p>
38-
*
39-
* @param vocabularySize
40-
* The size of the model's vocabulary
41-
* @param temperature
42-
* A value controlling randomness in sampling:
43-
* <ul>
44-
* <li>0.0f: No randomness (greedy sampling)</li>
45-
* <li>1.0f: Standard sampling from unmodified distribution</li>
46-
* <li>&lt;1.0f: More deterministic (sharper distribution)</li>
47-
* <li>&gt;1.0f: More random (flatter distribution)</li>
48-
* </ul>
49-
* @param topp
50-
* The cumulative probability threshold for nucleus sampling (0.0-1.0).
51-
* <ul>
52-
* <li>Values ≤0 or ≥1: Disables top-p sampling</li>
53-
* <li>Values in (0,1): Restricts sampling to tokens comprising the top p probability mass</li>
54-
* </ul>
55-
* @param rngSeed
56-
* Seed value for the random number generator to ensure reproducibility
57-
* @return A configured {@link Sampler} that implements the selected sampling strategy and handles both tensor and array-based logits
58-
* @throws IllegalArgumentException
59-
* if logits are of an unsupported type
60-
*/
61-
public static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
62-
Sampler sampler;
63-
if (temperature == 0.0f) {
64-
// greedy argmax sampling: take the token with the highest probability
65-
sampler = Sampler.TENSOR_ARGMAX; // Use TENSOR_ARGMAX instead of ARGMAX
66-
} else {
67-
// we sample from this distribution to get the next token
68-
RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed);
69-
Sampler innerSampler;
70-
// Determine whether to use top-p (nucleus) sampling
71-
if (topp <= 0 || topp >= 1) {
72-
// If topp is outside (0,1), use standard categorical sampling
73-
// This samples directly from the probability distribution
74-
innerSampler = new CategoricalSampler(rng);
75-
} else {
76-
// Use top-p (nucleus) sampling with the specified threshold
77-
// This restricts sampling to only the most likely tokens that
78-
// cumulatively comprise the top p probability mass
79-
innerSampler = new ToppSampler(vocabularySize, topp, rng);
80-
}
81-
82-
// Create a sampler that:
83-
// 1. Applies temperature scaling to the logits
84-
// 2. Converts logits to probabilities using softmax
85-
// 3. Delegates the actual sampling to the appropriate inner sampler
86-
sampler = logits -> {
87-
// Handle different logits formats to support both CPU and GPU paths
88-
if (logits instanceof FloatTensor) {
89-
// For CPU path using FloatTensor
90-
FloatTensor tensorLogits = (FloatTensor) logits;
91-
// Apply temperature scaling - lower values make distribution more peaked
92-
tensorLogits.divideInPlace(0, tensorLogits.size(), temperature);
93-
// Convert logits to probabilities using softmax
94-
tensorLogits.softmaxInPlace(0, tensorLogits.size());
95-
} else if (logits instanceof FloatArray) {
96-
// For GPU path using FloatArray
97-
FloatArray arrayLogits = (FloatArray) logits;
98-
// Apply the same operations but using FloatArray-specific methods for TornadoVM data types
99-
FloatArrayUtils.divideInPlace(arrayLogits, 0, arrayLogits.getSize(), temperature);
100-
FloatArrayUtils.softmaxInPlace(arrayLogits, 0, arrayLogits.getSize());
101-
} else {
102-
// If logits are neither FloatTensor nor FloatArray, throw an exception
103-
throw new IllegalArgumentException("Unsupported logits type: " + (logits != null ? logits.getClass().getName() : "null"));
104-
}
105-
return innerSampler.sampleToken(logits);
106-
};
107-
}
108-
return sampler;
109-
}
11025

11126
/**
11227
* Loads the language model based on the given options.

src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,112 @@
11
package org.beehive.gpullama3.inference.sampler;
22

3+
import org.beehive.gpullama3.Options;
34
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
5+
import org.beehive.gpullama3.model.Model;
6+
import org.beehive.gpullama3.tornadovm.FloatArrayUtils;
47
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
58

9+
import java.util.random.RandomGenerator;
10+
import java.util.random.RandomGeneratorFactory;
11+
612
/**
713
* Generic interface for sampling tokens from probability distributions.
814
* Supports both FloatTensor and FloatArray tensor implementations.
915
*/
1016
@FunctionalInterface
1117
public interface Sampler {
18+
19+
/**
20+
* Creates and configures a sampler for token generation based on specified parameters.
21+
*
22+
* <p>This method selects an appropriate sampling strategy for next-token prediction
23+
* in language model inference. It supports several sampling approaches:</p>
24+
*
25+
* <ul>
26+
* <li>Greedy sampling (temperature = 0): Always selects the most probable token</li>
27+
* <li>Temperature sampling: Adjusts probability distribution sharpness</li>
28+
* <li>Top-p (nucleus) sampling: Considers only tokens comprising the top p probability mass</li>
29+
* </ul>
30+
*
31+
* <p>The method handles both {@link FloatTensor} and {@link FloatArray} logits types
32+
* to support both CPU and GPU execution paths.</p>
33+
*
34+
* @param vocabularySize
35+
* The size of the model's vocabulary
36+
* @param temperature
37+
* A value controlling randomness in sampling:
38+
* <ul>
39+
* <li>0.0f: No randomness (greedy sampling)</li>
40+
* <li>1.0f: Standard sampling from unmodified distribution</li>
41+
* <li>&lt;1.0f: More deterministic (sharper distribution)</li>
42+
* <li>&gt;1.0f: More random (flatter distribution)</li>
43+
* </ul>
44+
* @param topp
45+
* The cumulative probability threshold for nucleus sampling (0.0-1.0).
46+
* <ul>
47+
* <li>Values ≤0 or ≥1: Disables top-p sampling</li>
48+
* <li>Values in (0,1): Restricts sampling to tokens comprising the top p probability mass</li>
49+
* </ul>
50+
* @param rngSeed
51+
* Seed value for the random number generator to ensure reproducibility
52+
* @return A configured {@link Sampler} that implements the selected sampling strategy and handles both tensor and array-based logits
53+
* @throws IllegalArgumentException
54+
* if logits are of an unsupported type
55+
*/
56+
static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
57+
Sampler sampler;
58+
if (temperature == 0.0f) {
59+
// greedy argmax sampling: take the token with the highest probability
60+
sampler = Sampler.TENSOR_ARGMAX; // Use TENSOR_ARGMAX instead of ARGMAX
61+
} else {
62+
// we sample from this distribution to get the next token
63+
RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed);
64+
Sampler innerSampler;
65+
// Determine whether to use top-p (nucleus) sampling
66+
if (topp <= 0 || topp >= 1) {
67+
// If topp is outside (0,1), use standard categorical sampling
68+
// This samples directly from the probability distribution
69+
innerSampler = new CategoricalSampler(rng);
70+
} else {
71+
// Use top-p (nucleus) sampling with the specified threshold
72+
// This restricts sampling to only the most likely tokens that
73+
// cumulatively comprise the top p probability mass
74+
innerSampler = new ToppSampler(vocabularySize, topp, rng);
75+
}
76+
77+
// Create a sampler that:
78+
// 1. Applies temperature scaling to the logits
79+
// 2. Converts logits to probabilities using softmax
80+
// 3. Delegates the actual sampling to the appropriate inner sampler
81+
sampler = logits -> {
82+
// Handle different logits formats to support both CPU and GPU paths
83+
if (logits instanceof FloatTensor) {
84+
// For CPU path using FloatTensor
85+
FloatTensor tensorLogits = (FloatTensor) logits;
86+
// Apply temperature scaling - lower values make distribution more peaked
87+
tensorLogits.divideInPlace(0, tensorLogits.size(), temperature);
88+
// Convert logits to probabilities using softmax
89+
tensorLogits.softmaxInPlace(0, tensorLogits.size());
90+
} else if (logits instanceof FloatArray) {
91+
// For GPU path using FloatArray
92+
FloatArray arrayLogits = (FloatArray) logits;
93+
// Apply the same operations but using FloatArray-specific methods for TornadoVM data types
94+
FloatArrayUtils.divideInPlace(arrayLogits, 0, arrayLogits.getSize(), temperature);
95+
FloatArrayUtils.softmaxInPlace(arrayLogits, 0, arrayLogits.getSize());
96+
} else {
97+
// If logits are neither FloatTensor nor FloatArray, throw an exception
98+
throw new IllegalArgumentException("Unsupported logits type: " + (logits != null ? logits.getClass().getName() : "null"));
99+
}
100+
return innerSampler.sampleToken(logits);
101+
};
102+
}
103+
return sampler;
104+
}
105+
106+
static Sampler createSampler(Model model, Options options) {
107+
return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
108+
}
109+
12110
/**
13111
* Sample a token from the provided tensor.
14112
*

0 commit comments

Comments
 (0)