@@ -186,9 +186,7 @@ def __init__(
186186
187187 chain_mapping = [pl .read_csv (path ) for path in chain_mapping_paths ]
188188 # Increment chain cluster IDs to avoid overlap
189- chain_cluster_nums = [
190- mapping .get_column ("cluster_id" ).max () for mapping in chain_mapping
191- ]
189+ chain_cluster_nums = [mapping .get_column ("cluster_id" ).max () for mapping in chain_mapping ]
192190 for i in range (1 , len (chain_mapping )):
193191 chain_mapping [i ] = chain_mapping [i ].with_columns (
194192 (pl .col ("cluster_id" ) + sum (chain_cluster_nums [:i ])).alias ("cluster_id" )
@@ -218,20 +216,6 @@ def __init__(
218216 compute_interface_weights (interface_mapping , self .alphas , self .betas ["interface" ]),
219217 )
220218
221- # Add additional information to the cluster IDs
222- chain_mapping = chain_mapping .with_columns (
223- (pl .col ("molecule_id" ) + "-" + pl .col ("cluster_id" ).cast (pl .String )).alias ("cluster_id" )
224- )
225- interface_mapping = interface_mapping .with_columns (
226- (
227- pl .col ("interface_molecule_id_1" )
228- + "-"
229- + pl .col ("interface_molecule_id_2" )
230- + "-"
231- + pl .col ("interface_cluster_id" ).cast (pl .String )
232- ).alias ("interface_cluster_id" )
233- )
234-
235219 # Concatenate chain and interface mappings
236220 chain_mapping = chain_mapping .with_columns (
237221 [
@@ -247,7 +231,9 @@ def __init__(
247231 [
248232 pl .col ("interface_chain_id_1" ).alias ("chain_id_1" ),
249233 pl .col ("interface_chain_id_2" ).alias ("chain_id_2" ),
250- pl .col ("interface_cluster_id" ).alias ("cluster_id" ),
234+ (
235+ pl .col ("interface_cluster_id" ) + chain_mapping .get_column ("cluster_id" ).max ()
236+ ).alias ("cluster_id" ),
251237 ]
252238 )
253239 interface_mapping = interface_mapping .select (
0 commit comments