99from scripts .cluster_pdb_mmcifs import CLUSTERING_MOLECULE_TYPE
1010
1111
12- @typecheck
1312def get_chain_count (molecule_type : CLUSTERING_MOLECULE_TYPE ) -> Tuple [int , int , int ]:
1413 """
1514 Returns the number of protein, nucleic acid, and ligand chains in a
@@ -31,7 +30,6 @@ def get_chain_count(molecule_type: CLUSTERING_MOLECULE_TYPE) -> Tuple[int, int,
3130 raise ValueError (f"Unknown molecule type: { molecule_type } " )
3231
3332
34- @typecheck
3533def calculate_weight (
3634 alphas : Dict [str , float ],
3735 beta : float ,
@@ -84,7 +82,7 @@ def get_interface_weight(
8482def get_cluster_sizes (
8583 mapping : pl .DataFrame ,
8684 cluster_id_col : str ,
87- ) -> Dict [str , int ]:
85+ ) -> Dict [int , int ]:
8886 """
8987 Returns a dictionary where keys are cluster IDs and values are the number
9088 of chains/interfaces in the cluster.
@@ -186,26 +184,18 @@ def __init__(
186184 if not isinstance (chain_mapping_paths , list ):
187185 chain_mapping_paths = [chain_mapping_paths ]
188186
189- chain_mapping = []
190- for path in chain_mapping_paths :
191- mapping = pl .read_csv (path )
192- mapping = mapping .with_columns (
193- (pl .col ("molecule_id" ) + "-" + pl .col ("cluster_id" ).cast (pl .String )).alias (
194- "cluster_id"
195- )
187+ chain_mapping = [pl .read_csv (path ) for path in chain_mapping_paths ]
188+ # Increment chain cluster IDs to avoid overlap
189+ chain_cluster_nums = [
190+ mapping .get_column ("cluster_id" ).max () for mapping in chain_mapping
191+ ]
192+ for i in range (1 , len (chain_mapping )):
193+ chain_mapping [i ] = chain_mapping [i ].with_columns (
194+ (pl .col ("cluster_id" ) + sum (chain_cluster_nums [:i ])).alias ("cluster_id" )
196195 )
197- chain_mapping . append ( mapping )
196+
198197 chain_mapping = pl .concat (chain_mapping )
199198 interface_mapping = pl .read_csv (interface_mapping_path )
200- interface_mapping = interface_mapping .with_columns (
201- (
202- pl .col ("interface_molecule_id_1" )
203- + "-"
204- + pl .col ("interface_molecule_id_2" )
205- + "-"
206- + pl .col ("interface_cluster_id" ).cast (pl .String )
207- ).alias ("interface_cluster_id" )
208- )
209199
210200 # Filter out unwanted PDB IDs
211201 if len (pdb_ids_to_skip ) > 0 :
@@ -228,6 +218,20 @@ def __init__(
228218 compute_interface_weights (interface_mapping , self .alphas , self .betas ["interface" ]),
229219 )
230220
221+ # Add additional information to the cluster IDs
222+ chain_mapping = chain_mapping .with_columns (
223+ (pl .col ("molecule_id" ) + "-" + pl .col ("cluster_id" ).cast (pl .String )).alias ("cluster_id" )
224+ )
225+ interface_mapping = interface_mapping .with_columns (
226+ (
227+ pl .col ("interface_molecule_id_1" )
228+ + "-"
229+ + pl .col ("interface_molecule_id_2" )
230+ + "-"
231+ + pl .col ("interface_cluster_id" ).cast (pl .String )
232+ ).alias ("interface_cluster_id" )
233+ )
234+
231235 # Concatenate chain and interface mappings
232236 chain_mapping = chain_mapping .with_columns (
233237 [
0 commit comments