11package uk .ac .ebi .spot .ols .service ;
22
33import com .google .gson .Gson ;
4+ import com .google .gson .JsonArray ;
5+ import com .google .gson .JsonElement ;
46import com .google .gson .JsonObject ;
57import org .springframework .beans .factory .annotation .Value ;
68import org .springframework .stereotype .Service ;
79
10+ import jakarta .annotation .PostConstruct ;
811import java .io .IOException ;
12+ import java .io .Reader ;
913import java .net .URI ;
1014import java .net .http .HttpClient ;
1115import java .net .http .HttpRequest ;
1216import java .net .http .HttpResponse ;
17+ import java .nio .file .DirectoryStream ;
18+ import java .nio .file .Files ;
19+ import java .nio .file .Path ;
20+ import java .nio .file .Paths ;
1321import java .time .Duration ;
1422import java .util .List ;
23+ import java .util .Map ;
24+ import java .util .Set ;
25+ import java .util .concurrent .ConcurrentHashMap ;
26+ import java .util .regex .Matcher ;
27+ import java .util .regex .Pattern ;
1528
1629/**
1730 * Client for the OLS embedding service.
31+ *
32+ * Handles PCA transformations locally: when a PCA model name is requested
33+ * (e.g. "model_pca512"), the client calls the embedding service with the
34+ * base model name ("model") and applies the PCA transform using a JSON
35+ * file loaded from the configured PCA models directory.
1836 */
1937@ Service
2038public class EmbeddingServiceClient {
2139
2240 @ Value ("${ols.embedding.service.url:#{null}}" )
2341 private String embeddingServiceUrl ;
42+
43+ @ Value ("${ols.embedding.pca.models.dir:#{null}}" )
44+ private String pcaModelsDir ;
2445
2546 private final HttpClient httpClient = HttpClient .newBuilder ()
2647 .version (HttpClient .Version .HTTP_1_1 )
2748 .connectTimeout (Duration .ofSeconds (30 ))
2849 .build ();
2950 private final Gson gson = new Gson ();
51+
52+ // PCA model name (e.g. "model_pca512") -> PcaModel
53+ private final Map <String , PcaModel > pcaModels = new ConcurrentHashMap <>();
54+
55+ private static final Pattern PCA_PATTERN = Pattern .compile ("^(.+)_pca(\\ d+)$" );
56+
57+ private static class PcaModel {
58+ final String baseModelName ;
59+ final int nComponents ;
60+ final double [] mean ; // length = n_features
61+ final double [][] components ; // shape = (n_features, n_components)
62+
63+ PcaModel (String baseModelName , int nComponents , double [] mean , double [][] components ) {
64+ this .baseModelName = baseModelName ;
65+ this .nComponents = nComponents ;
66+ this .mean = mean ;
67+ this .components = components ;
68+ }
69+ }
70+
71+ @ PostConstruct
72+ public void init () {
73+ loadPcaModels ();
74+ }
75+
76+ private void loadPcaModels () {
77+ if (pcaModelsDir == null || pcaModelsDir .isEmpty ()) {
78+ return ;
79+ }
80+ Path dir = Paths .get (pcaModelsDir );
81+ if (!Files .isDirectory (dir )) {
82+ System .err .println ("PCA models directory does not exist: " + pcaModelsDir );
83+ return ;
84+ }
85+
86+ try (DirectoryStream <Path > stream = Files .newDirectoryStream (dir , "*_pca*.json" )) {
87+ for (Path file : stream ) {
88+ String filename = file .getFileName ().toString ();
89+ // Expected format: {base_model}_pca{n}.json
90+ String stem = filename .replaceFirst ("\\ .json$" , "" );
91+ Matcher m = PCA_PATTERN .matcher (stem );
92+ if (!m .matches ()) continue ;
93+
94+ String baseModelName = m .group (1 );
95+ int nComponents = Integer .parseInt (m .group (2 ));
96+ String pcaModelName = stem ;
97+
98+ System .err .println ("Loading PCA model: " + pcaModelName + " from " + file );
99+
100+ try (Reader reader = Files .newBufferedReader (file )) {
101+ JsonObject json = gson .fromJson (reader , JsonObject .class );
102+
103+ double [] mean = toDoubleArray (json .getAsJsonArray ("mean" ));
104+ double [][] components = toDoubleArray2D (json .getAsJsonArray ("components" ));
105+
106+ pcaModels .put (pcaModelName , new PcaModel (baseModelName , nComponents , mean , components ));
107+ System .err .println ("Loaded PCA model: " + pcaModelName +
108+ " (base=" + baseModelName + ", components=" + nComponents +
109+ ", features=" + mean .length + ")" );
110+ }
111+ }
112+ } catch (IOException e ) {
113+ System .err .println ("Error loading PCA models from " + pcaModelsDir + ": " + e .getMessage ());
114+ }
115+ }
116+
117+ private static double [] toDoubleArray (JsonArray arr ) {
118+ double [] result = new double [arr .size ()];
119+ for (int i = 0 ; i < arr .size (); i ++) {
120+ result [i ] = arr .get (i ).getAsDouble ();
121+ }
122+ return result ;
123+ }
124+
125+ private static double [][] toDoubleArray2D (JsonArray arr ) {
126+ double [][] result = new double [arr .size ()][];
127+ for (int i = 0 ; i < arr .size (); i ++) {
128+ result [i ] = toDoubleArray (arr .get (i ).getAsJsonArray ());
129+ }
130+ return result ;
131+ }
132+
133+ /**
134+ * Apply PCA transform: (x - mean) @ components
135+ */
136+ private float [] applyPca (float [] embedding , PcaModel pca ) {
137+ int nFeatures = pca .mean .length ;
138+ int nComponents = pca .nComponents ;
139+ float [] result = new float [nComponents ];
140+
141+ for (int j = 0 ; j < nComponents ; j ++) {
142+ double sum = 0.0 ;
143+ for (int i = 0 ; i < nFeatures ; i ++) {
144+ sum += ((double ) embedding [i ] - pca .mean [i ]) * pca .components [i ][j ];
145+ }
146+ result [j ] = (float ) sum ;
147+ }
148+ return result ;
149+ }
30150
31151 /**
32152 * Get list of available models from the embedding service.
33- * Queries the /models endpoint to get the current list .
153+ * Includes PCA model variants loaded from JSON files .
34154 */
35155 public List <String > getAvailableModels () {
36156
@@ -47,45 +167,63 @@ public List<String> getAvailableModels() {
47167
48168 HttpResponse <String > response = httpClient .send (request , HttpResponse .BodyHandlers .ofString ());
49169
170+ Set <String > serviceModels = new java .util .HashSet <String >();
50171 if (response .statusCode () == 200 ) {
51172 JsonObject json = gson .fromJson (response .body (), JsonObject .class );
52173 if (json .has ("models" ) && json .get ("models" ).isJsonArray ()) {
53- List <String > models = new java .util .ArrayList <>();
54174 json .getAsJsonArray ("models" ).forEach (element -> {
55175 if (element .isJsonPrimitive ()) {
56- models .add (element .getAsString ());
176+ serviceModels .add (element .getAsString ());
57177 }
58178 });
59- return models ;
60179 }
61180 }
62- // Fallback to empty list if service is unavailable
63- return List .of ();
181+
182+ List <String > models = new java .util .ArrayList <>(serviceModels );
183+
184+ // Only include PCA models whose base model is available in the service
185+ for (var entry : pcaModels .entrySet ()) {
186+ if (serviceModels .contains (entry .getValue ().baseModelName )) {
187+ models .add (entry .getKey ());
188+ }
189+ }
190+
191+ return models ;
64192 } catch (Exception e ) {
65- // Service unavailable, return empty list
66193 return List .of ();
67194 }
68195 }
69196
70197 /**
71- * Embed a single text using the new embedding service.
72- * @param model The model name to use for embedding
73- * @param text The text to embed
74- * @return The embedding vector as a float array
198+ * Embed a single text. If the model name is a PCA model (e.g. "model_pca512"),
199+ * embeds with the base model and applies the PCA transform locally.
75200 */
76201 public float [] embedText (String model , String text ) throws IOException {
77202 return embedTexts (model , List .of (text ))[0 ];
78203 }
79204
80205 /**
81- * Embed multiple texts using the new embedding service.
82- * The service returns binary blob of float32 arrays.
83- * @param model The model name to use for embedding
84- * @param texts List of texts to embed
85- * @return Array of embedding vectors
206+ * Embed multiple texts. If the model name is a PCA model, embeds with the
207+ * base model and applies the PCA transform locally.
86208 */
87209 public float [][] embedTexts (String model , List <String > texts ) throws IOException {
88210
211+ PcaModel pca = pcaModels .get (model );
212+ String serviceModel = (pca != null ) ? pca .baseModelName : model ;
213+
214+ float [][] embeddings = embedTextsFromService (serviceModel , texts );
215+
216+ if (pca != null ) {
217+ for (int i = 0 ; i < embeddings .length ; i ++) {
218+ embeddings [i ] = applyPca (embeddings [i ], pca );
219+ }
220+ }
221+
222+ return embeddings ;
223+ }
224+
225+ private float [][] embedTextsFromService (String model , List <String > texts ) throws IOException {
226+
89227 if (embeddingServiceUrl == null || embeddingServiceUrl .isEmpty ()) {
90228 throw new IOException ("Embedding service URL is not configured" );
91229 }
@@ -104,25 +242,17 @@ public float[][] embedTexts(String model, List<String> texts) throws IOException
104242 .build ();
105243
106244 try {
107- System .err .println ("Embedding service request URL: " + embeddingServiceUrl );
108- System .err .println ("Request body: " + requestBodyJson );
109-
110245 HttpResponse <byte []> response = httpClient .send (request , HttpResponse .BodyHandlers .ofByteArray ());
111246
112- System .err .println ("Response status: " + response .statusCode ());
113- System .err .println ("Response headers: " + response .headers ().map ());
114-
115247 if (response .statusCode () == 200 ) {
116- // Get vector dimension from header
117248 String dimHeader = response .headers ().firstValue ("x-embedding-dim" ).orElse (null );
118249 if (dimHeader == null ) {
119250 throw new IOException ("Missing x-embedding-dim header in response" );
120251 }
121252 int dimension = Integer .parseInt (dimHeader );
122253
123- // Parse binary blob as float32 array
124254 byte [] binaryData = response .body ();
125- int expectedBytes = texts .size () * dimension * 4 ; // 4 bytes per float
255+ int expectedBytes = texts .size () * dimension * 4 ;
126256
127257 if (binaryData .length != expectedBytes ) {
128258 throw new IOException ("Unexpected response size: got " + binaryData .length +
0 commit comments