Skip to content

Commit 4323adb

Browse files
committed
chore(refactor): small refactoring to names and tensor logic
fix(transformer): fixed gate projection throwing NPE when freezing/unfreezing
1 parent 84bab6e commit 4323adb

File tree

19 files changed

+97
-60
lines changed

19 files changed

+97
-60
lines changed

brain4j-core/src/main/java/org/brain4j/core/Brain4J.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,11 @@ public static Device firstDevice() {
123123
throw new IllegalStateException("No GPU-acceleration device has been found!");
124124
}
125125

126-
return DeviceUtils.findDevice(devices.getFirst());
126+
Device device = DeviceUtils.findDevice(devices.getFirst());
127+
128+
if (device != null) Brain4J.initKernels(device);
129+
130+
return device;
127131
}
128132

129133
/**

brain4j-core/src/main/java/org/brain4j/core/importing/SafeTensorsConverter.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ private static Map<String, Tensor> load(ByteBuffer buffer) throws IOException {
114114

115115
for (Map.Entry<String, JsonElement> entry : header.entrySet()) {
116116
String name = entry.getKey();
117+
118+
if (name.equals("__metadata__")) continue;
119+
117120
JsonObject info = entry.getValue().getAsJsonObject();
118121

119122
JsonArray shapeArray = info.getAsJsonArray("shape");

brain4j-core/src/main/java/org/brain4j/core/layer/impl/RecurrentLayer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public Tensor[] forward(StatesCache cache, Tensor... inputs) {
7575
for (int t = 0; t < timesteps; t++) {
7676
Range[] ranges = new Range[] { Range.all(), Range.point(t), Range.all() };
7777

78-
Tensor timestepX = projectedInput.sliceGrad(ranges).squeeze(1);
78+
Tensor timestepX = projectedInput.sliceGrad(ranges).squeezeGrad(1);
7979
Tensor timestepH = hiddenState.matmulGrad(hiddenWeights);
8080

8181
hiddenState = timestepX.addGrad(timestepH).addGrad(hiddenBias).activateGrad(activation);

brain4j-core/src/main/java/org/brain4j/core/layer/impl/transformer/MultiHeadAttention.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ public Tensor[] forward(StatesCache cache, Tensor... inputs) {
133133
} else {
134134
Tensor reshaped = QKV.reshape(batch, seqLength, H, 3, d)
135135
.transpose(1, 2); // [B,H,L,3,d]
136-
Q = reshaped.slice(all, all, all, Range.point(0), all).squeeze(3);
137-
K = reshaped.slice(all, all, all, Range.point(1), all).squeeze(3);
138-
V = reshaped.slice(all, all, all, Range.point(2), all).squeeze(3);
136+
Q = reshaped.slice(all, all, all, Range.point(0), all).squeezeGrad(3);
137+
K = reshaped.slice(all, all, all, Range.point(1), all).squeezeGrad(3);
138+
V = reshaped.slice(all, all, all, Range.point(2), all).squeezeGrad(3);
139139
}
140140

141141
float scale = (float) (1.0 / Math.sqrt(d));

brain4j-core/src/main/java/org/brain4j/core/layer/impl/transformer/TransformerEncoder.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,24 +283,26 @@ public void loadWeights(Map<String, Tensor> mappedWeights) {
283283
@Override
284284
public Layer freeze() {
285285
upProjection.freeze();
286-
gateProjection.freeze();
287-
gateProjection.freeze();
288286
downProjection.freeze();
289287
normalizer1.freeze();
290288
normalizer2.freeze();
291289
attention.freeze();
290+
291+
if (gateProjection != null) gateProjection.freeze();
292+
292293
return super.freeze();
293294
}
294295

295296
@Override
296297
public Layer unfreeze() {
297298
upProjection.unfreeze();
298-
gateProjection.unfreeze();
299-
gateProjection.unfreeze();
300299
downProjection.unfreeze();
301300
normalizer1.unfreeze();
302301
normalizer2.unfreeze();
303302
attention.unfreeze();
303+
304+
if (gateProjection != null) gateProjection.unfreeze();
305+
304306
return super.unfreeze();
305307
}
306308

brain4j-core/src/main/java/org/brain4j/core/model/impl/Sequential.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ public Tensor[] predict(StatesCache cache, Tensor... inputs) {
5959
input = input.reshape(1, input.elements()); // reshape to [batch, input_size]
6060
}
6161

62-
buffer[i] = cache.isTraining() ? input.withGrad() : input;
62+
Tensor chosen = cache.isTraining() ? input.withGrad() : input;
63+
buffer[i] = chosen.to(device);
6364
}
6465

6566
for (Layer layer : layers) {

brain4j-core/src/main/java/org/brain4j/core/transformer/attention/MaskedMultiHeadAttention.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public Tensor[] forward(StatesCache cache, Tensor... inputs) {
114114
}
115115

116116
Range[] slicingRanges = {
117-
Range.all(), Range.point(seqLength - 1), Range.all()
117+
Range.all(), Range.point(seqLength - 1), Range.all()
118118
}; // [batch, 1, dim]
119119
Tensor cachedOutput = cache.get(outProj);
120120
Tensor cachedQKV = cache.get(weights);

brain4j-core/src/main/java/org/brain4j/core/transformer/attention/head/FlashAttentionHead.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,8 @@ public Tensor attend(Tensor input) {
6666
}
6767

6868
if (context != null) {
69-
return training
70-
? context.squeezeGrad(1) // [B,L,d]
71-
: context.squeeze(1);
69+
return context.squeezeGrad(1); // [B,L,d]
7270
}
73-
// fallthrough if context null
7471
}
7572

7673
// fallback to standard path with autograd support

brain4j-llm/src/main/java/org/brain4j/llm/core/architecture/impl/GPT2Adapter.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.brain4j.core.layer.impl.transformer.TransformerDecoder;
1111
import org.brain4j.core.layer.impl.utility.InputLayer;
1212
import org.brain4j.core.model.Model;
13+
import org.brain4j.core.model.ModelSpecs;
1314
import org.brain4j.llm.core.architecture.ArchitectureAdapter;
1415
import org.brain4j.math.data.StatesCache;
1516
import org.brain4j.math.tensor.Tensor;
@@ -31,7 +32,7 @@ public Model buildModel(JsonObject config, Map<String, Tensor> weights) {
3132
int context = config.get("n_ctx").getAsInt();
3233
int vocabSize = config.get("vocab_size").getAsInt();
3334

34-
OldSequential seq = OldSequential.of();
35+
ModelSpecs specs = ModelSpecs.of();
3536

3637
Tensor embedding = weights.get("wte.weight"); // embedding -> [vocab, dim]
3738
Tensor posEncode = weights.get("wpe.weight"); // pos encode -> [length, dim]
@@ -44,9 +45,9 @@ public Model buildModel(JsonObject config, Map<String, Tensor> weights) {
4445
vocabLayer.setWeights(embedding.transpose());
4546
posEncodeLayer.setWeights(posEncode);
4647

47-
seq.add(new InputLayer(-1));
48-
seq.add(embeddingLayer);
49-
seq.add(posEncodeLayer);
48+
specs.add(new InputLayer(-1).freeze());
49+
specs.add(embeddingLayer.freeze());
50+
specs.add(posEncodeLayer.freeze());
5051

5152
for (int i = 0; i < layers; i++) {
5253
String prefix = String.format("h.%s.", i);
@@ -91,7 +92,7 @@ public Model buildModel(JsonObject config, Map<String, Tensor> weights) {
9192
attention.setOutProj(attnOutWeight);
9293
attention.setOutBias(attnOutBias);
9394

94-
seq.add(decoder);
95+
specs.add(decoder.freeze());
9596
}
9697

9798
TokenSelectionLayer selectionLayer = new TokenSelectionLayer();
@@ -103,11 +104,11 @@ public Model buildModel(JsonObject config, Map<String, Tensor> weights) {
103104
normLayer.setWeights(lnGamma);
104105
normLayer.setBias(lnBeta);
105106

106-
seq.add(normLayer);
107-
seq.add(selectionLayer);
108-
seq.add(vocabLayer);
107+
specs.add(normLayer.freeze());
108+
specs.add(selectionLayer.freeze());
109+
specs.add(vocabLayer.freeze());
109110

110-
return seq;
111+
return specs.compile();
111112
}
112113

113114
static class TokenSelectionLayer extends Layer {
@@ -118,7 +119,7 @@ public Tensor[] forward(StatesCache cache, Tensor... inputs) {
118119
Tensor input = inputs[0]; // [batch, seq_len, dim]
119120
int seqLength = input.shapeAt(1);
120121

121-
return new Tensor[] { input.slice(Range.all(), Range.point(seqLength - 1), Range.all()).squeeze(1) };
122+
return new Tensor[] { input.slice(Range.all(), Range.point(seqLength - 1), Range.all()).squeezeGrad(1) };
122123
}
123124

124125
@Override

brain4j-llm/src/main/java/org/brain4j/llm/core/loader/ModelLoader.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.brain4j.llm.core.model.LLM;
1212
import org.brain4j.llm.download.callback.ProgressCallback;
1313
import org.brain4j.llm.download.manager.DownloadManager;
14-
import org.brain4j.math.commons.result.Result;
1514
import org.slf4j.Logger;
1615
import org.slf4j.LoggerFactory;
1716

@@ -56,7 +55,7 @@ public Tokenizer loadTokenizer(String tokenizerId) throws Exception {
5655
public Tokenizer loadTokenizer(String tokenizerId, LoadConfig config) throws Exception {
5756
log.info("Loading tokenizer: {}", tokenizerId);
5857

59-
ModelInfo info = client.getModelInfo(tokenizerId).unwrap();
58+
ModelInfo info = client.getModelInfo(tokenizerId);
6059
log.debug("Tokenizer info retrieved for: {} (resolved id: {})", tokenizerId, info.id());
6160

6261
String fileToDownload = "tokenizer.json";
@@ -75,7 +74,7 @@ public Tokenizer loadTokenizer(String tokenizerId, LoadConfig config) throws Exc
7574
public LLM loadModel(String modelId, LoadConfig config) throws Exception {
7675
log.info("Loading model: {}", modelId);
7776

78-
ModelInfo info = client.getModelInfo(modelId).unwrap();
77+
ModelInfo info = client.getModelInfo(modelId);
7978
log.debug("Model info retrieved for: {} (resolved id: {})", modelId, info.id());
8079

8180
List<String> filesToDownload = determineFilesToDownload(info, config);

0 commit comments

Comments (0)