Skip to content

Commit 86e364b

Browse files
committed
trivial numba, some update prints, playing around with params
1 parent 13656d7 commit 86e364b

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

CCA.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import copy
33
import csv
4+
import numba
45
from PCA import PCA_subcluster
56

67
def CCA_subcluster(R: np.ndarray, N: int, DF: float, kf: float,iter: int,N_subcl_perc: float ,ext_case: int ,tolerance: float=1e-7) -> tuple[bool,bool]:
@@ -134,8 +135,14 @@ def CCA_subcluster(R: np.ndarray, N: int, DF: float, kf: float,iter: int,N_subcl
134135

135136
if int(np.mod(I_total,2)) == 0:
136137
I_total = int(I_total/2)
138+
print("======================")
139+
print(f"{I_total = }")
140+
print("======================")
137141
else:
138142
I_total = int(np.floor(I_total/2)+1)
143+
print("======================")
144+
print(f"{I_total = }")
145+
print("======================")
139146

140147
X = X_next
141148
Y = Y_next
@@ -265,7 +272,8 @@ def CCA_identify_monomers(i_orden: np.ndarray):
265272
ID_mon[j] = i
266273
return ID_mon
267274

268-
def CCA_random_select_list(X1, Y1, Z1, R1, X_cm1, Y_cm1, Z_cm1, X2, Y2, Z2,R2, X_cm2, Y_cm2, Z_cm2, curr_list: np.ndarray, gamma_pc: float, gamma_real: bool, ext_case):
275+
@numba.njit()
276+
def CCA_random_select_list(X1: np.ndarray, Y1: np.ndarray, Z1: np.ndarray, R1: np.ndarray, X_cm1: np.ndarray, Y_cm1: np.ndarray, Z_cm1: np.ndarray, X2: np.ndarray, Y2: np.ndarray, Z2: np.ndarray,R2: np.ndarray, X_cm2: np.ndarray, Y_cm2: np.ndarray, Z_cm2: np.ndarray, curr_list: np.ndarray, gamma_pc: float, gamma_real: bool, ext_case: int):
269277
if gamma_real and ext_case == 1:
270278
for i in range(curr_list.shape[0]-1):
271279
d_i_min = np.sqrt(np.power(X1[i]-X_cm1,2) + np.power(Y1[i]-Y_cm1,2) + np.power(Z1[i]-Z_cm1,2)) - R1[i]
@@ -811,7 +819,8 @@ def CCA_2_sphere_intersec(sphere1: np.ndarray, sphere2: np.ndarray):
811819
return x,y,z, vec0, i_vec,j_vec
812820

813821

814-
def CCA_overlap_check(n1: int, n2: int, X1,X2,Y1,Y2,Z1,Z2,R1,R2):
822+
@numba.jit(nopython=True)
823+
def CCA_overlap_check(n1: int, n2: int, X1: np.ndarray,X2: np.ndarray,Y1: np.ndarray,Y2: np.ndarray,Z1: np.ndarray,Z2: np.ndarray,R1: np.ndarray,R2: np.ndarray):
815824
cov_max = 0
816825

817826
for i in range(n1):
@@ -823,7 +832,8 @@ def CCA_overlap_check(n1: int, n2: int, X1,X2,Y1,Y2,Z1,Z2,R1,R2):
823832
cov_max = c_ij
824833
return cov_max
825834

826-
def CCA_sticking_process_v2(CM2: np.ndarray, vec0: np.ndarray, X2_new,Y2_new,Z2_new, i_vec, j_vec, prev_cand):
835+
@numba.jit(nopython=True)
836+
def CCA_sticking_process_v2(CM2: np.ndarray, vec0: np.ndarray, X2_new: np.ndarray,Y2_new: np.ndarray,Z2_new: np.ndarray, i_vec: np.ndarray, j_vec: np.ndarray, prev_cand: int):
827837
uu = np.random.rand()
828838
theta_a = 2*np.pi * uu
829839

@@ -833,19 +843,23 @@ def CCA_sticking_process_v2(CM2: np.ndarray, vec0: np.ndarray, X2_new,Y2_new,Z2_
833843

834844
v1 = np.array([X2_new[prev_cand]-CM2[0], Y2_new[prev_cand]-CM2[1], Z2_new[prev_cand]-CM2[2]])
835845
v2 = np.array([x-CM2[0], y-CM2[1], z-CM2[2]])
836-
s_vec = np.cross(v1,v2)/np.linalg.norm(np.cross(v1,v2))
846+
s_vec = np.cross(v1,v2)/my_norm(np.cross(v1,v2))
847+
# print(f"{np.linalg.norm(np.cross(v1,v2)) = }")
848+
# print(f"{my_norm(np.cross(v1,v2)) = }")
837849

838-
if np.dot(v1,v2)/np.linalg.norm(np.dot(v1,v2)) > 1 or np.dot(v1,v2)/np.linalg.norm(np.dot(v1,v2)) < -1:
850+
if np.dot(v1,v2)/abs(np.dot(v1,v2)) > 1 or np.dot(v1,v2)/abs(np.dot(v1,v2)) < -1:
839851
angle = np.arccos(1)
840852
else:
841-
angle = np.arccos(np.dot(v1,v2)/(np.linalg.norm(v1)*np.linalg.norm(v2)))
853+
angle = np.arccos(np.dot(v1,v2)/(my_norm(v1)*my_norm(v2)))
842854

843855

844856
As = np.array([[0, -s_vec[2], s_vec[1]], [s_vec[2], 0, -s_vec[0]], [-s_vec[1], s_vec[0],0]])
845-
rot = np.identity(3) + np.sin(angle)*As + (1-np.cos(angle)) * np.matmul(As,As)
857+
# rot = np.identity(3) + np.sin(angle)*As + (1-np.cos(angle)) * np.matmul(As,As)
858+
rot = np.identity(3) + np.sin(angle)*As + (1-np.cos(angle)) * (As @ As)
846859

847860
for i in range(X2_new.shape[0]):
848-
new_c = np.matmul(rot, np.array([X2_new[i]-CM2[0], Y2_new[i]-CM2[1], Z2_new[i]-CM2[2]]))
861+
# new_c = np.matmul(rot, np.array([X2_new[i]-CM2[0], Y2_new[i]-CM2[1], Z2_new[i]-CM2[2]]))
862+
new_c = rot @ np.array([X2_new[i]-CM2[0], Y2_new[i]-CM2[1], Z2_new[i]-CM2[2]])
849863
X2_new[i] = CM2[0] + new_c[0]
850864
Y2_new[i] = CM2[1] + new_c[1]
851865
Z2_new[i] = CM2[2] + new_c[2]
@@ -868,3 +882,8 @@ def sort_rows(i_orden: np.ndarray):
868882
i_orden[krow,:] = temp
869883

870884
return i_orden
885+
886+
@numba.jit(nopython=True)
887+
def my_norm(a: np.ndarray) -> float:
888+
n = np.sqrt(np.power(a[0],2) + np.power(a[1],2) + np.power(a[2],2))
889+
return n

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# config
99
DF = 2.0
1010
Kf = 1.0
11-
N = 64
11+
N = 1024
1212
R0 = 0.01
1313
SIGMA = 0
1414
EXT_CASE = 0

0 commit comments

Comments
 (0)