Skip to content

Commit 14549eb

Browse files
committed
added numba fast matmul
1 parent d0e9458 commit 14549eb

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

openptv_python/ray_tracing.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Tuple
33

44
import numpy as np
5-
from numba import njit
5+
from numba import float64, int64, njit, prange
66

77
from .calibration import Calibration
88
from .parameters import MultimediaPar
@@ -306,3 +306,36 @@ def fast_ray_tracing(
306306
# out = vec_add(tmp1, tmp2)
307307

308308
# return X, out
309+
310+
# import numpy as np
311+
# import time
312+
313+
@njit(float64[:, :](float64[:, :], float64[:, :], float64[:, :], int64, int64, int64), parallel=True)
314+
def matmul_numba_optimized(a, b, c, m, n, k):
315+
for i in prange(m):
316+
for j in range(k):
317+
temp = 0.0
318+
for ll in range(n):
319+
temp += b[i, ll] * c[ll, j]
320+
a[i, j] = temp
321+
return a
322+
323+
# # Define the same inputs as in the C test
324+
# b = np.array([
325+
# [1.0, 2.0, 3.0],
326+
# [4.0, 5.0, 6.0]
327+
# ], dtype=np.float64)
328+
# c = np.array([
329+
# [1.0, 0.0],
330+
# [0.0, 1.0],
331+
# [1.0, 1.0]
332+
# ], dtype=np.float64)
333+
# a = np.zeros((2, 2), dtype=np.float64)
334+
335+
# start_time = time.time()
336+
# matmul_numba_optimized(a, b, c, 2, 3, 2)
337+
# end_time = time.time()
338+
339+
# print("Optimized Python Numba Time:", end_time - start_time, "seconds")
340+
# print("Result from Python function:")
341+
# print(a)

0 commit comments

Comments
 (0)