Skip to content

Commit f12c3fb

Browse files
committed
merge embedding code into main ols4 pipeline
1 parent 6748bdf commit f12c3fb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+9199
-179
lines changed

.github/workflows/docker-publish.yml

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,28 @@ jobs:
1717

1818
strategy:
1919
matrix:
20-
component: [dataload, backend, frontend, apitester4]
20+
component: [dataload, embed, backend, frontend, apitester4]
21+
include:
22+
- component: dataload
23+
context: .
24+
dockerfile: ./dataload/Dockerfile
25+
build-args: ""
26+
- component: embed
27+
context: ./dataload/embeddings
28+
dockerfile: ./dataload/embeddings/Dockerfile
29+
build-args: ""
30+
- component: backend
31+
context: .
32+
dockerfile: ./backend/Dockerfile
33+
build-args: ""
34+
- component: frontend
35+
context: ./frontend
36+
dockerfile: ./frontend/Dockerfile
37+
build-args: ""
38+
- component: apitester4
39+
context: ./apitester4
40+
dockerfile: ./apitester4/Dockerfile
41+
build-args: ""
2142

2243
steps:
2344
- name: Checkout repository
@@ -48,10 +69,11 @@ jobs:
4869
- name: Build and push ols4 ${{ matrix.component }} Docker image
4970
uses: docker/build-push-action@v5
5071
with:
51-
context: ${{ matrix.component == 'frontend' && './frontend' || (matrix.component == 'apitester4' && './apitester4' || '.') }}
52-
file: ${{ matrix.component == 'frontend' && './frontend/Dockerfile' || format('./{0}/Dockerfile', matrix.component) }}
53-
platforms: linux/amd64,linux/arm64
72+
context: ${{ matrix.context }}
73+
file: ${{ matrix.dockerfile }}
74+
platforms: ${{ matrix.component == 'embed' && 'linux/amd64' || 'linux/amd64,linux/arm64' }}
5475
push: true
76+
build-args: ${{ matrix.build-args }}
5577
tags: |
5678
ghcr.io/ebispot/ols4-${{ matrix.component }}:${{ github.sha }}
5779
ghcr.io/ebispot/ols4-${{ matrix.component }}:${{ github.ref_name }}

.github/workflows/security-container-scan.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ jobs:
4747
matrix:
4848
image:
4949
- ols4-dataload
50+
- ols4-embed
5051
- ols4-backend
5152
- ols4-frontend
5253

backend/src/main/java/uk/ac/ebi/spot/ols/service/EmbeddingServiceClient.java

Lines changed: 155 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,156 @@
11
package uk.ac.ebi.spot.ols.service;
22

33
import com.google.gson.Gson;
4+
import com.google.gson.JsonArray;
5+
import com.google.gson.JsonElement;
46
import com.google.gson.JsonObject;
57
import org.springframework.beans.factory.annotation.Value;
68
import org.springframework.stereotype.Service;
79

10+
import jakarta.annotation.PostConstruct;
811
import java.io.IOException;
12+
import java.io.Reader;
913
import java.net.URI;
1014
import java.net.http.HttpClient;
1115
import java.net.http.HttpRequest;
1216
import java.net.http.HttpResponse;
17+
import java.nio.file.DirectoryStream;
18+
import java.nio.file.Files;
19+
import java.nio.file.Path;
20+
import java.nio.file.Paths;
1321
import java.time.Duration;
1422
import java.util.List;
23+
import java.util.Map;
24+
import java.util.Set;
25+
import java.util.concurrent.ConcurrentHashMap;
26+
import java.util.regex.Matcher;
27+
import java.util.regex.Pattern;
1528

1629
/**
1730
* Client for the OLS embedding service.
31+
*
32+
* Handles PCA transformations locally: when a PCA model name is requested
33+
* (e.g. "model_pca512"), the client calls the embedding service with the
34+
* base model name ("model") and applies the PCA transform using a JSON
35+
* file loaded from the configured PCA models directory.
1836
*/
1937
@Service
2038
public class EmbeddingServiceClient {
2139

2240
@Value("${ols.embedding.service.url:#{null}}")
2341
private String embeddingServiceUrl;
42+
43+
@Value("${ols.embedding.pca.models.dir:#{null}}")
44+
private String pcaModelsDir;
2445

2546
private final HttpClient httpClient = HttpClient.newBuilder()
2647
.version(HttpClient.Version.HTTP_1_1)
2748
.connectTimeout(Duration.ofSeconds(30))
2849
.build();
2950
private final Gson gson = new Gson();
51+
52+
// PCA model name (e.g. "model_pca512") -> PcaModel
53+
private final Map<String, PcaModel> pcaModels = new ConcurrentHashMap<>();
54+
55+
private static final Pattern PCA_PATTERN = Pattern.compile("^(.+)_pca(\\d+)$");
56+
57+
private static class PcaModel {
58+
final String baseModelName;
59+
final int nComponents;
60+
final double[] mean; // length = n_features
61+
final double[][] components; // shape = (n_features, n_components)
62+
63+
PcaModel(String baseModelName, int nComponents, double[] mean, double[][] components) {
64+
this.baseModelName = baseModelName;
65+
this.nComponents = nComponents;
66+
this.mean = mean;
67+
this.components = components;
68+
}
69+
}
70+
71+
@PostConstruct
72+
public void init() {
73+
loadPcaModels();
74+
}
75+
76+
private void loadPcaModels() {
77+
if (pcaModelsDir == null || pcaModelsDir.isEmpty()) {
78+
return;
79+
}
80+
Path dir = Paths.get(pcaModelsDir);
81+
if (!Files.isDirectory(dir)) {
82+
System.err.println("PCA models directory does not exist: " + pcaModelsDir);
83+
return;
84+
}
85+
86+
try (DirectoryStream<Path> stream = Files.newDirectoryStream(dir, "*_pca*.json")) {
87+
for (Path file : stream) {
88+
String filename = file.getFileName().toString();
89+
// Expected format: {base_model}_pca{n}.json
90+
String stem = filename.replaceFirst("\\.json$", "");
91+
Matcher m = PCA_PATTERN.matcher(stem);
92+
if (!m.matches()) continue;
93+
94+
String baseModelName = m.group(1);
95+
int nComponents = Integer.parseInt(m.group(2));
96+
String pcaModelName = stem;
97+
98+
System.err.println("Loading PCA model: " + pcaModelName + " from " + file);
99+
100+
try (Reader reader = Files.newBufferedReader(file)) {
101+
JsonObject json = gson.fromJson(reader, JsonObject.class);
102+
103+
double[] mean = toDoubleArray(json.getAsJsonArray("mean"));
104+
double[][] components = toDoubleArray2D(json.getAsJsonArray("components"));
105+
106+
pcaModels.put(pcaModelName, new PcaModel(baseModelName, nComponents, mean, components));
107+
System.err.println("Loaded PCA model: " + pcaModelName +
108+
" (base=" + baseModelName + ", components=" + nComponents +
109+
", features=" + mean.length + ")");
110+
}
111+
}
112+
} catch (IOException e) {
113+
System.err.println("Error loading PCA models from " + pcaModelsDir + ": " + e.getMessage());
114+
}
115+
}
116+
117+
private static double[] toDoubleArray(JsonArray arr) {
118+
double[] result = new double[arr.size()];
119+
for (int i = 0; i < arr.size(); i++) {
120+
result[i] = arr.get(i).getAsDouble();
121+
}
122+
return result;
123+
}
124+
125+
private static double[][] toDoubleArray2D(JsonArray arr) {
126+
double[][] result = new double[arr.size()][];
127+
for (int i = 0; i < arr.size(); i++) {
128+
result[i] = toDoubleArray(arr.get(i).getAsJsonArray());
129+
}
130+
return result;
131+
}
132+
133+
/**
134+
* Apply PCA transform: (x - mean) @ components
135+
*/
136+
private float[] applyPca(float[] embedding, PcaModel pca) {
137+
int nFeatures = pca.mean.length;
138+
int nComponents = pca.nComponents;
139+
float[] result = new float[nComponents];
140+
141+
for (int j = 0; j < nComponents; j++) {
142+
double sum = 0.0;
143+
for (int i = 0; i < nFeatures; i++) {
144+
sum += ((double) embedding[i] - pca.mean[i]) * pca.components[i][j];
145+
}
146+
result[j] = (float) sum;
147+
}
148+
return result;
149+
}
30150

31151
/**
32152
* Get list of available models from the embedding service.
33-
* Queries the /models endpoint to get the current list.
153+
* Includes PCA model variants loaded from JSON files.
34154
*/
35155
public List<String> getAvailableModels() {
36156

@@ -47,45 +167,63 @@ public List<String> getAvailableModels() {
47167

48168
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
49169

170+
Set<String> serviceModels = new java.util.HashSet<String>();
50171
if (response.statusCode() == 200) {
51172
JsonObject json = gson.fromJson(response.body(), JsonObject.class);
52173
if (json.has("models") && json.get("models").isJsonArray()) {
53-
List<String> models = new java.util.ArrayList<>();
54174
json.getAsJsonArray("models").forEach(element -> {
55175
if (element.isJsonPrimitive()) {
56-
models.add(element.getAsString());
176+
serviceModels.add(element.getAsString());
57177
}
58178
});
59-
return models;
60179
}
61180
}
62-
// Fallback to empty list if service is unavailable
63-
return List.of();
181+
182+
List<String> models = new java.util.ArrayList<>(serviceModels);
183+
184+
// Only include PCA models whose base model is available in the service
185+
for (var entry : pcaModels.entrySet()) {
186+
if (serviceModels.contains(entry.getValue().baseModelName)) {
187+
models.add(entry.getKey());
188+
}
189+
}
190+
191+
return models;
64192
} catch (Exception e) {
65-
// Service unavailable, return empty list
66193
return List.of();
67194
}
68195
}
69196

70197
/**
71-
* Embed a single text using the new embedding service.
72-
* @param model The model name to use for embedding
73-
* @param text The text to embed
74-
* @return The embedding vector as a float array
198+
* Embed a single text. If the model name is a PCA model (e.g. "model_pca512"),
199+
* embeds with the base model and applies the PCA transform locally.
75200
*/
76201
public float[] embedText(String model, String text) throws IOException {
77202
return embedTexts(model, List.of(text))[0];
78203
}
79204

80205
/**
81-
* Embed multiple texts using the new embedding service.
82-
* The service returns binary blob of float32 arrays.
83-
* @param model The model name to use for embedding
84-
* @param texts List of texts to embed
85-
* @return Array of embedding vectors
206+
* Embed multiple texts. If the model name is a PCA model, embeds with the
207+
* base model and applies the PCA transform locally.
86208
*/
87209
public float[][] embedTexts(String model, List<String> texts) throws IOException {
88210

211+
PcaModel pca = pcaModels.get(model);
212+
String serviceModel = (pca != null) ? pca.baseModelName : model;
213+
214+
float[][] embeddings = embedTextsFromService(serviceModel, texts);
215+
216+
if (pca != null) {
217+
for (int i = 0; i < embeddings.length; i++) {
218+
embeddings[i] = applyPca(embeddings[i], pca);
219+
}
220+
}
221+
222+
return embeddings;
223+
}
224+
225+
private float[][] embedTextsFromService(String model, List<String> texts) throws IOException {
226+
89227
if(embeddingServiceUrl == null || embeddingServiceUrl.isEmpty()) {
90228
throw new IOException("Embedding service URL is not configured");
91229
}
@@ -104,25 +242,17 @@ public float[][] embedTexts(String model, List<String> texts) throws IOException
104242
.build();
105243

106244
try {
107-
System.err.println("Embedding service request URL: " + embeddingServiceUrl);
108-
System.err.println("Request body: " + requestBodyJson);
109-
110245
HttpResponse<byte[]> response = httpClient.send(request, HttpResponse.BodyHandlers.ofByteArray());
111246

112-
System.err.println("Response status: " + response.statusCode());
113-
System.err.println("Response headers: " + response.headers().map());
114-
115247
if (response.statusCode() == 200) {
116-
// Get vector dimension from header
117248
String dimHeader = response.headers().firstValue("x-embedding-dim").orElse(null);
118249
if (dimHeader == null) {
119250
throw new IOException("Missing x-embedding-dim header in response");
120251
}
121252
int dimension = Integer.parseInt(dimHeader);
122253

123-
// Parse binary blob as float32 array
124254
byte[] binaryData = response.body();
125-
int expectedBytes = texts.size() * dimension * 4; // 4 bytes per float
255+
int expectedBytes = texts.size() * dimension * 4;
126256

127257
if (binaryData.length != expectedBytes) {
128258
throw new IOException("Unexpected response size: got " + binaryData.length +

dataload.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ else
3939
echo "Using OLS4_DATALOAD_IMAGE: $OLS4_DATALOAD_IMAGE"
4040
fi
4141

42+
if [ -z "${OLS4_EMBED_IMAGE:-}" ]; then
43+
echo "OLS4_EMBED_IMAGE environment variable is not set. Using dev image."
44+
OLS4_EMBED_IMAGE="ghcr.io/ebispot/ols4-embed:dev"
45+
else
46+
echo "Using OLS4_EMBED_IMAGE: $OLS4_EMBED_IMAGE"
47+
fi
48+
4249
TMP_DIR="$OLS_HOME/tmp"
4350
OUT_DIR="$OLS_HOME/out"
4451

@@ -64,7 +71,10 @@ docker run \
6471
-e OLS4_CONFIG="$OLS4_CONFIG" \
6572
-e OLS4_DATALOAD_ARGS="${OLS4_DATALOAD_ARGS:-}" \
6673
-e OLS_EMBEDDINGS_PATH="$OLS_EMBEDDINGS_PATH" \
74+
-e OLS_EMBEDDINGS_CONFIG="${OLS_EMBEDDINGS_CONFIG:-}" \
75+
-e OLS_EMBEDDINGS_PREV="${OLS_EMBEDDINGS_PREV:-}" \
6776
-e OLS4_DATALOAD_IMAGE="$OLS4_DATALOAD_IMAGE" \
77+
-e OLS4_EMBED_IMAGE="$OLS4_EMBED_IMAGE" \
6878
-e NXF_USRMAP="${HOST_UID}" \
6979
-e HOST_UID="${HOST_UID}" \
7080
-e HOST_GID="${HOST_GID}" \

0 commit comments

Comments
 (0)