@@ -60,6 +60,10 @@ def log_retry_attempt(retry_state):
6060
6161
6262def _validate_protein_tensor_input (input ):
63+ if isinstance (input , ESMProteinError ):
64+ raise ValueError (
65+ f"Input must be an ESMProteinTensor instance, but received an ESMProteinError instead: { input .error_code } { input .error_msg } "
66+ )
6367 if not isinstance (input , ESMProteinTensor ):
6468 raise ValueError (
6569 f"Input must be an ESMProteinTensor instance, but received { type (input )} instead. "
@@ -71,14 +75,25 @@ class SequenceStructureForgeInferenceClient:
7175 def __init__ (
7276 self ,
7377 url : str = "https://forge.evolutionaryscale.ai" ,
78+ model : str | None = None ,
7479 token : str = "" ,
7580 request_timeout : int | None = None ,
7681 ):
82+ """
83+ Forge client for folding and inverse folding between sequence and structure spaces.
84+
85+ Args:
86+ url: URL of the Forge server.
87+ model: Name of the model to be used for folding / inv folding.
88+ token: API token.
89+ request_timeout: Override the system default request timeout, in seconds.
90+ """
7791 if token == "" :
7892 raise RuntimeError (
7993 "Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
8094 )
8195 self .url = url
96+ self .model = model
8297 self .token = token
8398 self .headers = {"Authorization" : f"Bearer { self .token } " }
8499 self .request_timeout = request_timeout
@@ -89,9 +104,19 @@ def fold(
89104 potential_sequence_of_concern : bool ,
90105 model_name : str | None = None ,
91106 ) -> ESMProtein | ESMProteinError :
107+ """Predict coordinates for a protein sequence.
108+
109+ Args:
110+ sequence: Protein sequence to be folded.
111+ potential_sequence_of_concern: Self disclosed potential_of_concern bit.
112+ This bit is largely ignored by the fold() endpoint.
113+ model_name: Override the client level model name if needed.
114+ """
92115 request = {"sequence" : sequence }
93116 if model_name is not None :
94117 request ["model" ] = model_name
118+ elif self .model is not None :
119+ request ["model" ] = self .model
95120 try :
96121 data = self ._post ("fold" , request , potential_sequence_of_concern )
97122 except ESMProteinError as e :
@@ -109,6 +134,17 @@ def inverse_fold(
109134 potential_sequence_of_concern : bool ,
110135 model_name : str | None = None ,
111136 ) -> ESMProtein | ESMProteinError :
137+ """Generate protein sequence from its structure.
138+
139+ This endpoint is only supported by generative models like ESM3.
140+
141+ Args:
142+ coordinates: Protein sequence coordinates to be inversely folded.
143+ config: Configurations related to inverse folding generation.
144+ potential_sequence_of_concern: Self disclosed potential_of_concern bit.
145+ Requires special permission to use.
146+ model_name: Override the client level model name if needed.
147+ """
112148 inverse_folding_config = {
113149 "invalid_ids" : config .invalid_ids ,
114150 "temperature" : config .temperature ,
@@ -119,6 +155,8 @@ def inverse_fold(
119155 }
120156 if model_name is not None :
121157 request ["model" ] = model_name
158+ elif self .model is not None :
159+ request ["model" ] = self .model
122160 try :
123161 data = self ._post ("inverse_fold" , request , potential_sequence_of_concern )
124162 except ESMProteinError as e :
@@ -208,6 +246,16 @@ def wrapper(instance, *args, **kwargs):
208246
209247 @retry_decorator
210248 def generate (self , input : ProteinType , config : GenerationConfig ) -> ProteinType :
249+ if isinstance (input , ESMProteinError ):
250+ raise ValueError (
251+ f"Input must be an ESMProtein or ESMProteinTensor instance, but received an ESMProteinError instead: { input .error_code } { input .error_msg } "
252+ )
253+ assert isinstance (input , ESMProtein ) or isinstance (input , ESMProteinTensor )
254+ if input .sequence is not None and config .num_steps > len (input .sequence ):
255+ config .num_steps = len (input .sequence )
256+ print (
257+ "Warning: num_steps cannot exceed sequence length. Setting num_steps to sequence length."
258+ )
211259 if isinstance (input , ESMProtein ):
212260 output = self .__generate_protein (input , config )
213261 elif isinstance (input , ESMProteinTensor ):
0 commit comments