33import numpy as np
44import pandas as pd
55from openprotein .base import APISession
6- from openprotein .schemas import Job , PredictorMetadata
6+ from openprotein .schemas import (
7+ CVJob ,
8+ Job ,
9+ PredictJob ,
10+ PredictMultiJob ,
11+ PredictMultiSingleSiteJob ,
12+ PredictorMetadata ,
13+ PredictSingleSiteJob ,
14+ TrainJob ,
15+ )
716from pydantic import TypeAdapter
817
918PATH_PREFIX = "v1/predictor"
@@ -99,13 +108,33 @@ def predictor_fit_gp_post(
99108 body ["description" ] = description
100109
101110 response = session .post (endpoint , json = body )
102- return Job .model_validate (response .json ())
111+ return TrainJob .model_validate (response .json ())
103112
104113
105114def predictor_delete (session : APISession , predictor_id : str ):
106115 raise NotImplementedError ()
107116
108117
118+ def predictor_crossvalidate_post (
119+ session : APISession , predictor_id : str , n_splits : int | None = None
120+ ):
121+ endpoint = PATH_PREFIX + f"/{ predictor_id } /crossvalidate"
122+
123+ params = {}
124+ if n_splits is not None :
125+ params ["n_splits" ] = n_splits
126+ response = session .post (endpoint , params = params )
127+
128+ return CVJob .model_validate (response .json ())
129+
130+
131+ def predictor_crossvalidate_get (session : APISession , crossvalidate_job_id : str ):
132+ endpoint = PATH_PREFIX + f"/crossvalidate/{ crossvalidate_job_id } "
133+
134+ response = session .get (endpoint )
135+ return response .content
136+
137+
109138def predictor_predict_post (
110139 session : APISession , predictor_id : str , sequences : list [bytes ] | list [str ]
111140):
@@ -117,7 +146,25 @@ def predictor_predict_post(
117146 }
118147 response = session .post (endpoint , json = body )
119148
120- return Job .model_validate (response .json ())
149+ return PredictJob .model_validate (response .json ())
150+
151+
152+ def predictor_predict_single_site_post (
153+ session : APISession ,
154+ predictor_id : str ,
155+ base_sequence : bytes | str ,
156+ ):
157+ endpoint = PATH_PREFIX + f"/{ predictor_id } /predict_single_site"
158+
159+ base_sequence = (
160+ base_sequence .decode () if isinstance (base_sequence , bytes ) else base_sequence
161+ )
162+ body = {
163+ "base_sequence" : base_sequence ,
164+ }
165+ response = session .post (endpoint , json = body )
166+
167+ return PredictSingleSiteJob .model_validate (response .json ())
121168
122169
123170def predictor_predict_get_sequences (
@@ -179,9 +226,9 @@ def predictor_predict_get_batched_result(
179226 return response .content
180227
181228
182- def decode_score (data : bytes , batched : bool = False ) -> tuple [np .ndarray , np .ndarray ]:
229+ def decode_predict (data : bytes , batched : bool = False ) -> tuple [np .ndarray , np .ndarray ]:
183230 """
184- Decode embedding .
231+ Decode prediction scores .
185232
186233 Args:
187234 data (bytes): raw bytes encoding the array received over the API
@@ -203,3 +250,25 @@ def decode_score(data: bytes, batched: bool = False) -> tuple[np.ndarray, np.nda
203250 mus = scores [:, ::2 ]
204251 vars = scores [:, 1 ::2 ]
205252 return mus , vars
253+
254+
255+ def decode_crossvalidate (data : bytes ) -> tuple [np .ndarray , np .ndarray , np .ndarray ]:
256+ """
257+ Decode crossvalidate scores.
258+
259+ Args:
260+ data (bytes): raw bytes encoding the array received over the API
261+
262+ Returns:
263+ mus (np.ndarray): decoded array of means
264+ vars (np.ndarray): decoded array of variances
265+ """
266+ s = io .BytesIO (data )
267+ # should contain header and sequence column
268+ df = pd .read_csv (s )
269+ scores = df .values
270+ # row_num, seq, measurement_name, y, y_mu, y_var
271+ y = scores [:, 3 ]
272+ mus = scores [:, 4 ]
273+ vars = scores [:, 5 ]
274+ return y , mus , vars
0 commit comments