
Commit 04ba434

Refactor and improve code formatting across multiple files
Applied consistent formatting using @formatter:off/@formatter:on directives to enhance readability. Improved class documentation with detailed JavaDoc comments for methods and constructors, clarifying their purpose and parameters. Adjusted code style for multiline constructs and added missing comments where necessary.
1 parent: dabbdfb · commit: 04ba434

26 files changed: +366, -124 lines
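
A note on the formatter directives this commit relies on: //@formatter:off and //@formatter:on are marker comments honored by common Java formatters (e.g., IntelliJ IDEA and Eclipse, when formatter markers are enabled); the region between them keeps its hand-written alignment. A minimal illustration (the LOOKUP table is invented for the example):

    // @formatter:off  -- the formatter leaves this region exactly as written
    static final int[][] LOOKUP = {
            { 0b11100000, 0b11000000, 2 },
            { 0b11110000, 0b11100000, 3 },
            { 0b11111000, 0b11110000, 4 },
    };
    // @formatter:on   -- automatic formatting resumes below this marker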

src/main/java/com/example/aot/AOT.java

Lines changed: 5 additions & 10 deletions
@@ -32,8 +32,8 @@ public final class AOT {
 
     static LlamaModelLoader modelLoader;
 
-
-    record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {}
+    record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {
+    }
 
     private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF"));
 
@@ -49,12 +49,8 @@ private static PartialModel preLoadGGUF(String modelPath) {
         GGUF gguf = GGUF.loadModel(path);
         try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
             modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false);
-            return new PartialModel(
-                    path.getFileName().toString(),
-                    modelLoader.loadModel(), // TODO: needs proper handling for AOT
-                    gguf.getTensorDataOffset(),
-                    gguf.getTensorInfos()
-            );
+            return new PartialModel(path.getFileName().toString(), modelLoader.loadModel(), // TODO: needs proper handling for AOT
+                    gguf.getTensorDataOffset(), gguf.getTensorInfos());
         }
     } catch (IOException e) {
         throw new RuntimeException(e);
@@ -78,8 +74,7 @@ public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IO
            return null;
        }
        Llama baseModel = preLoaded.model();
-       try (var timer = Timer.log("Load tensors from pre-loaded model");
-               var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
+       try (var timer = Timer.log("Load tensors from pre-loaded model"); var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
            // Load only the tensors (mmap slices).
            Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
            Weights weights = modelLoader.loadWeights(tensorEntries, baseModel.configuration());
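
For context on how this preload path is used: PRELOADED_GGUF is populated during class initialization from the llama.PreloadGGUF system property, and tryUsePreLoaded returns the cached model when the requested path matches. A minimal sketch of a caller (AOT and Model come from this diff; the demo class, the omitted import for Model, and the context length are illustrative assumptions):

    import java.nio.file.Path;
    import com.example.aot.AOT;

    public class PreloadDemo {
        // Launch with: java -Dllama.PreloadGGUF=/models/model.gguf ...
        public static void main(String[] args) throws Exception {
            Path modelPath = Path.of(System.getProperty("llama.PreloadGGUF"));
            // Returns the pre-loaded model when the path matches; null means
            // the caller must fall back to a regular load.
            Model model = AOT.tryUsePreLoaded(modelPath, 4096 /* assumed context length */);
            System.out.println(model != null ? "reused pre-loaded model" : "fallback load needed");
        }
    }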

src/main/java/com/example/auxiliary/Utf8Mask.java

Lines changed: 2 additions & 0 deletions
@@ -2,9 +2,11 @@
 
 /** mask of a byte-sequence in UTF-8 encoding */
 public record Utf8Mask(int mask, int pattern, int len) {
+    //@formatter:off
     public static final Utf8Mask[] MASKS = {
             new Utf8Mask(0b11100000, 0b11000000, 2),
             new Utf8Mask(0b11110000, 0b11100000, 3),
             new Utf8Mask(0b11111000, 0b11110000, 4)
     };
+    //@formatter:on
 }
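
These masks classify UTF-8 lead bytes: a byte b opens an n-byte sequence when (b & mask) == pattern. A small sketch of how they might be applied (the helper below is illustrative, not code from this repository):

    // Illustrative helper: length of the UTF-8 sequence started by a lead byte.
    static int sequenceLength(byte b) {
        int ub = b & 0xFF; // widen without sign extension
        if ((ub & 0b1000_0000) == 0) {
            return 1; // single-byte (ASCII) character
        }
        for (Utf8Mask m : Utf8Mask.MASKS) {
            if ((ub & m.mask()) == m.pattern()) {
                return m.len(); // 2-, 3-, or 4-byte sequence
            }
        }
        return -1; // continuation byte or invalid lead byte
    }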

src/main/java/com/example/inference/state/LlamaState.java

Lines changed: 13 additions & 3 deletions
@@ -8,6 +8,16 @@
 
 import java.util.stream.Stream;
 
+/**
+ * Represents the state of the Llama model during inference.
+ * This class extends {@link State} to include model-specific functionalities
+ * and configurations tailored for the Llama model.
+ *
+ * <p><b>Note 1:</b> LlamaState contains additional fields for TornadoVM wrappers
+ * to enable GPU-accelerated processing of the model.</p>
+ *
+ * <p><b>Note 2:</b> This state implementation is also used for the Mistral model.</p>
+ */
 public final class LlamaState extends State {
 
     public LlamaState(Configuration config, int batchsize) {
@@ -56,9 +66,9 @@ protected StateFields createStateFields(Configuration config) {
         fields.positionHolder = new IntArray(1);
 
         // Temporary arrays
-        fields.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
-        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
-        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
+        fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
 
         return fields;
     }
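
The added spacing also highlights the sizing idiom: for positive integers, (n + d - 1) / d is integer ceiling division, so each temporary array holds one slot per work-group of size localSize plus one leading slot. A worked check with illustrative numbers:

    // Ceiling division: (n + d - 1) / d == ceil(n / d) for positive ints.
    int dim = 2050, localSize = 256;                 // illustrative values
    int groups = (dim + localSize - 1) / localSize;  // = 9, since 8 * 256 = 2048 < 2050
    int slots = 1 + groups;                          // = 10: one extra leading slot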

src/main/java/com/example/inference/state/Qwen3State.java

Lines changed: 14 additions & 7 deletions
@@ -9,6 +9,15 @@
 
 import java.util.stream.Stream;
 
+/**
+ * Represents the state of the Qwen3 model during inference.
+ * This class extends {@link State} to include model-specific functionalities
+ * and configurations tailored for the Qwen3 model.
+ *
+ * <p><b>Note 1:</b> Qwen3State contains additional fields for TornadoVM wrappers
+ * to enable GPU-accelerated processing of the model.</p>
+ *
+ */
 public final class Qwen3State extends State {
 
     // Qwen3 specific fields
@@ -52,10 +61,8 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());
 
         // Key-value cache with Qwen3 dimensions
-        fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa))
-                .limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
-        fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa))
-                .limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
+        fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
+        fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
 
         // TornadoVM wrappers with Qwen3-specific sizes
         fields.wrapX = new FloatArray(config.dim());
@@ -76,9 +83,9 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.positionHolder = new IntArray(1);
 
         // Temporary arrays
-        fields.temp = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
-        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
-        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize-1) / localSize));
+        fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
 
         return fields;
     }
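
The collapsed Stream.generate(...).limit(n).toArray(...) calls allocate one independent tensor per transformer layer. An equivalent loop form makes the resulting shape explicit (FloatTensor and ArrayFloatTensor are the repo's types shown above; the dimensions are illustrative):

    // Loop equivalent of the Stream-based per-layer KV-cache allocation above.
    int layers = 28, contextLength = 8192, nEmbdGqa = 512; // illustrative values
    FloatTensor[] keyCache = new FloatTensor[layers];
    FloatTensor[] valueCache = new FloatTensor[layers];
    for (int l = 0; l < layers; l++) {
        // one (contextLength x nEmbdGqa) buffer per layer, never shared
        keyCache[l] = ArrayFloatTensor.allocate(contextLength, nEmbdGqa);
        valueCache[l] = ArrayFloatTensor.allocate(contextLength, nEmbdGqa);
    }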

src/main/java/com/example/inference/state/State.java

Lines changed: 16 additions & 2 deletions
@@ -6,9 +6,23 @@
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;
 
 /**
- * Base class for State
+ * Represents the base state structure used during LLM inference.
+ * This class provides a common foundation for handling state-related data and functionalities
+ * that can be extended by model-specific implementations.
+ *
+ * <p><b>Key Responsibilities:</b></p>
+ * <ul>
+ * <li>Defines core structures to store and access model state data required for computation.</li>
+ * <li>Can be extended by model-specific state classes (e.g., {@link LlamaState}, {@link Qwen3State}).</li>
+ * </ul>
+ *
+ * <p><b>Usage:</b> Extend `State` to implement model-specific state configurations
+ * while reusing the common structure and functionality provided by this class.</p>
+ *
+ * <p><b>Note:</b> This class is designed to be generic and does not include any
+ * model-specific behavior or fields. Those should be implemented in subclasses.</p>
 */
-public abstract class State{
+public abstract class State {
 
     // current wave of activations
     public final FloatTensor x; // activation at current time stamp (dim,)
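
The new JavaDoc spells out the extension contract; a minimal sketch of a conforming subclass, with the constructor and createStateFields shapes inferred from the LlamaState and Qwen3State diffs above (the super-constructor signature and StateFields default constructor are assumptions):

    // Sketch, not part of the commit: the extension pattern State's JavaDoc describes.
    public final class MyModelState extends State {

        public MyModelState(Configuration config, int batchsize) {
            super(config, batchsize); // assumed super-constructor shape
        }

        @Override
        protected StateFields createStateFields(Configuration config) {
            StateFields fields = new StateFields(); // assumed default constructor
            // allocate model-specific buffers here, as LlamaState/Qwen3State do
            return fields;
        }
    }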

src/main/java/com/example/inference/weights/Weights.java

Lines changed: 10 additions & 0 deletions
@@ -2,6 +2,16 @@
 
 import com.example.core.model.GGMLType;
 
+/**
+ * The GPULlama3.java utilizes two distinct weight types:
+ * <ul>
+ * <li><b>StandardWeights:</b> Designed for standard Java-based inference on the CPU.</li>
+ * <li><b>TornadoWeights:</b> Optimized for GPU-accelerated inference using TornadoVM.</li>
+ * </ul>
+ *
+ * The packages <code>weights.standard</code> and <code>weights.tornado</code> define
+ * base classes and model-specific implementations for weights in their respective formats.
+ */
 public interface Weights {
 
     GGMLType getWeightType();
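
The CPU/GPU split documented here implies a selection point in the model-loading path. A hedged sketch of that dispatch (only the Weights interface and getWeightType() appear in this diff; the two load helpers and the useTornadoVM flag are hypothetical):

    // Hypothetical dispatch between the two weight families.
    Weights weights = useTornadoVM
            ? loadTornadoWeights(tensorEntries, config)   // weights.tornado, GPU path
            : loadStandardWeights(tensorEntries, config); // weights.standard, CPU path
    GGMLType type = weights.getWeightType(); // the one method both families share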

src/main/java/com/example/inference/weights/standard/LlamaStandardWeights.java

Lines changed: 60 additions & 3 deletions
@@ -3,12 +3,69 @@
 import com.example.core.model.GGMLType;
 import com.example.core.model.tensor.FloatTensor;
 
+/**
+ * A model-specific implementation of {@link StandardWeights} for the Llama model.
+ * This class encapsulates the weights required for performing inference
+ * using the Llama model in the standard CPU-based format.
+ *
+ * <p><b>Note:</b> This weight format is also used for the Mistral model.</p>
+ */
 public class LlamaStandardWeights extends StandardWeights {
 
-    public LlamaStandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatTensor[] rms_ffn_weight,
-            FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatTensor rms_final_weight, FloatTensor freq_cis_real, FloatTensor freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
-        super(token_embedding_table, rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, w2, w3, rms_final_weight, freq_cis_real, freq_cis_imag, wcls, weightType);
+    // @formatter:off
+    /**
+     * Constructor for LlamaStandardWeights.
+     *
+     * @param token_embedding_table The token embedding table tensor.
+     * @param rms_att_weight Array of RMS attention weights tensors.
+     * @param wq Array of query weight tensors.
+     * @param wk Array of key weight tensors.
+     * @param wv Array of value weight tensors.
+     * @param wo Array of output weight tensors.
+     * @param rms_ffn_weight Array of RMS feed-forward network weights.
+     * @param w1 Array of first feed-forward layer weights.
+     * @param w2 Array of second feed-forward layer weights.
+     * @param w3 Array of third feed-forward layer weights.
+     * @param rms_final_weight Final RMS weight tensor.
+     * @param freq_cis_real Real part of frequency cis tensor.
+     * @param freq_cis_imag Imaginary part of frequency cis tensor.
+     * @param wcls Class token weight tensor.
+     * @param weightType The GGML weight type.
+     */
+    public LlamaStandardWeights(
+            FloatTensor token_embedding_table,
+            FloatTensor[] rms_att_weight,
+            FloatTensor[] wq,
+            FloatTensor[] wk,
+            FloatTensor[] wv,
+            FloatTensor[] wo,
+            FloatTensor[] rms_ffn_weight,
+            FloatTensor[] w1,
+            FloatTensor[] w2,
+            FloatTensor[] w3,
+            FloatTensor rms_final_weight,
+            FloatTensor freq_cis_real,
+            FloatTensor freq_cis_imag,
+            FloatTensor wcls,
+            GGMLType weightType) {
+        // call to StandardWeights constructor
+        super(token_embedding_table,
+                rms_att_weight,
+                wq,
+                wk,
+                wv,
+                wo,
+                rms_ffn_weight,
+                w1,
+                w2,
+                w3,
+                rms_final_weight,
+                freq_cis_real,
+                freq_cis_imag,
+                wcls,
+                weightType);
     }
+    // @formatter:on
 
     @Override
     public GGMLType getWeightType() {
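
For orientation, these parameter names follow the llama2.c convention, and the tensors map onto the transformer forward pass roughly as sketched below (conceptual pseudocode, not code from this repository):

    // Per layer l (conceptual; the actual tensor ops live elsewhere in the repo):
    //   xb  = rmsnorm(x, rms_att_weight[l])                 // pre-attention norm
    //   q = wq[l]*xb;  k = wk[l]*xb;  v = wv[l]*xb          // QKV projections
    //   q, k rotated via the RoPE tables freq_cis_real / freq_cis_imag
    //   x  += wo[l] * attention(q, k, v)                    // output projection
    //   hb  = rmsnorm(x, rms_ffn_weight[l])                 // pre-FFN norm
    //   x  += w2[l] * (silu(w1[l]*hb) * (w3[l]*hb))         // SwiGLU FFN
    // Finally: logits = wcls * rmsnorm(x, rms_final_weight) // classifier head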

src/main/java/com/example/inference/weights/standard/Qwen3StandardWeights.java

Lines changed: 61 additions & 6 deletions
@@ -3,20 +3,75 @@
 import com.example.core.model.GGMLType;
 import com.example.core.model.tensor.FloatTensor;
 
+/**
+ * A model-specific implementation of {@link StandardWeights} for the Qwen-3 model.
+ * This class defines the weights required for performing inference
+ * using the Qwen-3 model in the standard CPU-based format.
+ */
 public class Qwen3StandardWeights extends StandardWeights {
     public final FloatTensor[] attnKNorm, attnQNorm;
 
-    public Qwen3StandardWeights(FloatTensor token_embedding_table, FloatTensor[] rms_att_weight,
-            FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo,
-            FloatTensor[] attnKNorm, FloatTensor[] attnQNorm,
+    // @formatter:off
+    /**
+     * Constructor for {@code Qwen3StandardWeights}.
+     *
+     * @param token_embedding_table The token embedding table, used to map tokens to embeddings.
+     * @param rms_att_weight The array of Root Mean Square (RMS) attention weights.
+     * @param wq The array of query weight tensors for attention layers.
+     * @param wk The array of key weight tensors for attention layers.
+     * @param wv The array of value weight tensors for attention layers.
+     * @param wo The array of output weight tensors for attention layers.
+     * @param attnKNorm The array of normalization tensors for attention keys.
+     * @param attnQNorm The array of normalization tensors for attention queries.
+     * @param rms_ffn_weight The array of RMS weights for feed-forward neural network layers.
+     * @param w1 The array of first weight tensors for feed-forward layers.
+     * @param w2 The array of second weight tensors for feed-forward layers.
+     * @param w3 The array of third weight tensors for feed-forward layers.
+     * @param rms_final_weight The RMS weight used for final output normalization.
+     * @param freq_cis_real The real part of the frequency position encodings.
+     * @param freq_cis_imag The imaginary part of the frequency position encodings.
+     * @param wcls The weight tensor for the classification head.
+     * @param weightType The type of the weights, defined as {@link GGMLType}.
+     */
+    public Qwen3StandardWeights(
+            FloatTensor token_embedding_table,
+            FloatTensor[] rms_att_weight,
+            FloatTensor[] wq,
+            FloatTensor[] wk,
+            FloatTensor[] wv,
+            FloatTensor[] wo,
+            FloatTensor[] attnKNorm,
+            FloatTensor[] attnQNorm,
             FloatTensor[] rms_ffn_weight,
-            FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3,
-            FloatTensor rms_final_weight, FloatTensor freq_cis_real, FloatTensor freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
+            FloatTensor[] w1,
+            FloatTensor[] w2,
+            FloatTensor[] w3,
+            FloatTensor rms_final_weight,
+            FloatTensor freq_cis_real,
+            FloatTensor freq_cis_imag,
+            FloatTensor wcls,
+            GGMLType weightType) {
         // call to StandardWeights constructor
-        super(token_embedding_table, rms_att_weight, wq, wk, wv, wo, rms_ffn_weight, w1, w2, w3, rms_final_weight, freq_cis_real, freq_cis_imag, wcls, weightType);
+        super(token_embedding_table,
+                rms_att_weight,
+                wq,
+                wk,
+                wv,
+                wo,
+                rms_ffn_weight,
+                w1,
+                w2,
+                w3,
+                rms_final_weight,
+                freq_cis_real,
+                freq_cis_imag,
+                wcls,
+                weightType);
+        // init Qwen3-specific fields
         this.attnKNorm = attnKNorm;
         this.attnQNorm = attnQNorm;
     }
+    // @formatter:on
 
     @Override
     public GGMLType getWeightType() {