Skip to content

Commit bd580e2

Browse files
committed
feat: finished model serialization API
1 parent 5e2220e commit bd580e2

24 files changed

+383
-119
lines changed

brain4j-core/src/main/java/org/brain4j/core/importing/ModelLoaders.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
package org.brain4j.core.importing;
22

3+
import org.brain4j.core.importing.impl.BrainLoader;
34
import org.brain4j.core.importing.impl.OnnxLoader;
45
import org.brain4j.core.model.Model;
56

7+
import java.io.IOException;
68
import java.nio.file.Files;
79
import java.nio.file.Paths;
810

911
public class ModelLoaders {
1012

13+
public static Model fromFile(String path) throws Exception {
14+
byte[] data = Files.readAllBytes(Paths.get(path));
15+
16+
BrainLoader loader = new BrainLoader();
17+
18+
return loader.deserialize(data);
19+
}
20+
1121
public static Model fromOnnx(String path) throws Exception {
1222
byte[] data = Files.readAllBytes(Paths.get(path));
1323

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package org.brain4j.core.importing.impl;
2+
3+
import org.brain4j.common.Commons;
4+
import org.brain4j.core.importing.ModelLoader;
5+
import org.brain4j.core.importing.proto.ProtoModel;
6+
import org.brain4j.core.layer.Layer;
7+
import org.brain4j.core.loss.LossFunction;
8+
import org.brain4j.core.model.Model;
9+
import org.brain4j.core.model.impl.Sequential;
10+
11+
import java.io.*;
12+
import java.lang.reflect.Constructor;
13+
import java.time.Instant;
14+
import java.util.*;
15+
16+
public class BrainLoader implements ModelLoader {
17+
18+
@Override
19+
public Model deserialize(byte[] bytes) throws Exception {
20+
ProtoModel.Model protoModel = ProtoModel.Model.parseFrom(bytes);
21+
Map<Integer, Layer> positionMap = new HashMap<>();
22+
23+
for (ProtoModel.Layer layer : protoModel.getLayersList()) {
24+
String layerType = layer.getType();
25+
String layerId = layer.getName();
26+
27+
String[] parts = layerId.split("\\.");
28+
29+
if (parts.length == 0) {
30+
throw new IllegalArgumentException("Layer does not match format!");
31+
}
32+
33+
int position = Integer.parseInt(parts[1]);
34+
35+
Class<?> clazz = Class.forName(layerType);
36+
37+
Constructor<?> constructor = clazz.getDeclaredConstructor();
38+
constructor.setAccessible(true);
39+
40+
Layer wrapped = (Layer) constructor.newInstance();
41+
List<ProtoModel.Tensor> tensors = new ArrayList<>();
42+
43+
for (ProtoModel.Tensor tensor : protoModel.getWeightsList()) {
44+
if (!tensor.getName().startsWith(layerId)) continue;
45+
46+
tensors.add(tensor);
47+
}
48+
49+
positionMap.put(position, wrapped);
50+
wrapped.deserialize(tensors, layer);
51+
}
52+
53+
List<Integer> positions = new ArrayList<>(positionMap.keySet());
54+
Collections.sort(positions);
55+
56+
Sequential model = Sequential.of();
57+
58+
for (int pos : positions) {
59+
model.add(positionMap.get(pos));
60+
}
61+
62+
String lossFunctionClass = protoModel.getLossFunction();
63+
64+
LossFunction function = Commons.newInstance(lossFunctionClass);
65+
model.setLossFunction(function);
66+
67+
return model;
68+
}
69+
70+
@Override
71+
public void serialize(Model model, File file) throws IOException {
72+
ProtoModel.Model.Builder builder =
73+
ProtoModel.Model.newBuilder()
74+
.setVersion(1)
75+
.setName(file.getName())
76+
.setCreated(Instant.now().toString())
77+
.setLossFunction(model.lossFunction().getClass().getName());
78+
79+
List<Layer> layers = model.layers();
80+
81+
for (int i = 0; i < layers.size(); i++) {
82+
Layer layer = layers.get(i);
83+
String name = layer.getClass().getSimpleName().toLowerCase();
84+
String id = name + "." + i;
85+
86+
ProtoModel.Layer.Builder layerBuilder =
87+
ProtoModel.Layer.newBuilder()
88+
.setName(id)
89+
.setType(layer.getClass().getName())
90+
.setDimension(layer.size());
91+
92+
List<ProtoModel.Tensor.Builder> tensorsBuilders = layer.serialize(layerBuilder);
93+
List<ProtoModel.Tensor> tensors = new ArrayList<>();
94+
95+
for (ProtoModel.Tensor.Builder tensorBuilder : tensorsBuilders) {
96+
tensorBuilder.setName(id + "." + tensorBuilder.getName());
97+
tensors.add(tensorBuilder.build());
98+
}
99+
100+
builder.addLayers(layerBuilder.build());
101+
builder.addAllWeights(tensors);
102+
}
103+
104+
builder.build().writeTo(new FileOutputStream(file));
105+
}
106+
}

brain4j-core/src/main/java/org/brain4j/core/importing/impl/BrainProtocol.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

0 commit comments

Comments
 (0)