11#include <immintrin.h>
22#include <omp.h>
33#include <stdio.h>
4+ #include "../../include/Core/matrix_multiply.h"
45
56/**
67 * @brief Performs a matrix multiplication using SIMD instructions (AVX).
1819void matrix_multiply_simd (const float * A , const float * B_T , float * C ,
1920 int M , int N , int K , float scale )
2021{
22+ if (A == NULL || B_T == NULL || C == NULL )
23+ {
24+ fprintf (stderr , "Error: Null pointer passed to matrix_multiply_simd.\n" );
25+ return ;
26+ }
27+
28+ if (M <= 0 || N <= 0 || K <= 0 )
29+ {
30+ fprintf (stderr , "Error: Invalid matrix dimensions passed to matrix_multiply_simd.\n" );
31+ return ;
32+ }
33+
2134#pragma omp parallel for collapse(2)
2235 for (int i = 0 ; i < M ; i ++ )
2336 {
@@ -29,7 +42,7 @@ void matrix_multiply_simd(const float *A, const float *B_T, float *C,
2942 for (k = 0 ; k <= K - 8 ; k += 8 )
3043 {
3144 __m256 a = _mm256_loadu_ps (& A [i * K + k ]);
32- __m256 b = _mm256_loadu_ps (& B_T [j * K + k ]); // Access row in transposed B
45+ __m256 b = _mm256_loadu_ps (& B_T [j * K + k ]);
3346 sum = _mm256_add_ps (sum , _mm256_mul_ps (a , b ));
3447 }
3548
@@ -62,8 +75,70 @@ void matrix_multiply_simd(const float *A, const float *B_T, float *C,
6275 */
void transpose_matrix(const float *B, float *B_T, int K, int N)
{
    /*
     * Writes the transpose of the row-major K x N matrix B into B_T, so that
     * B_T[j * K + i] == B[i * N + j] for every 0 <= i < K and 0 <= j < N.
     *
     * On a NULL pointer or a non-positive dimension an error is reported on
     * stderr and B_T is left untouched.
     */
    if (B == NULL || B_T == NULL)
    {
        fprintf(stderr, "Error: Null pointer passed to transpose_matrix.\n");
        return;
    }

    if (K <= 0 || N <= 0)
    {
        fprintf(stderr, "Error: Invalid matrix dimensions passed to transpose_matrix.\n");
        return;
    }

    /* Widen once so the flat-index arithmetic below is done in size_t:
     * computing j * K + i in int overflows (undefined behaviour) as soon as
     * K * N exceeds INT_MAX. */
    const size_t rows = (size_t)K;
    const size_t cols = (size_t)N;

#pragma omp parallel for collapse(2)
    for (int i = 0; i < K; ++i)
    {
        for (int j = 0; j < N; ++j)
        {
            B_T[(size_t)j * rows + (size_t)i] = B[(size_t)i * cols + (size_t)j];
        }
    }
}
95+
/**
 * @brief Performs a matrix multiplication, transposing the second matrix first if needed.
 *
 * C = A * B * scale
 *
 * When is_transposed is non-zero, B is passed unchanged as the transposed
 * operand of matrix_multiply_simd(). Otherwise a temporary 32-byte-aligned
 * buffer is allocated, filled with the transpose of B via transpose_matrix(),
 * handed to matrix_multiply_simd(), and freed before returning.
 *
 * On a NULL pointer, a non-positive dimension, or a failed allocation, an
 * error message is written to stderr and the function returns without
 * invoking the multiplication kernel.
 *
 * @param A Pointer to the first matrix (M x K).
 * @param B Pointer to the second matrix (K x N), or its transpose when
 *          is_transposed is non-zero.
 * @param C Pointer to the result matrix (M x N).
 * @param M Number of rows in matrix A.
 * @param N Number of columns in matrix B.
 * @param K Number of columns in matrix A and rows in matrix B.
 * @param scale Scaling factor to apply to the result.
 * @param is_transposed Flag indicating if B is already transposed (non-zero if true, 0 otherwise).
 */
void matrix_multiply(const float *A, const float *B, float *C,
                     int M, int N, int K, float scale, int is_transposed)
{
    if (A == NULL || B == NULL || C == NULL)
    {
        fprintf(stderr, "Error: Null pointer passed to matrix_multiply.\n");
        return;
    }

    if (M <= 0 || N <= 0 || K <= 0)
    {
        fprintf(stderr, "Error: Invalid matrix dimensions passed to matrix_multiply.\n");
        return;
    }

    if (is_transposed)
    {
        matrix_multiply_simd(A, B, C, M, N, K, scale);
        return;
    }

    const size_t alignment = 32;

    /* Largest element count whose byte size can still be rounded up to the
     * alignment without wrapping size_t. */
    const size_t max_elems = ((size_t)-1 - (alignment - 1)) / sizeof(float);

    float *B_T = NULL;
    if ((size_t)K <= max_elems / (size_t)N)
    {
        /* Size is computed in size_t (N * K in int could overflow) and
         * rounded up because C11 aligned_alloc requires the size to be an
         * integral multiple of the alignment. The padding is never read. */
        const size_t bytes = (size_t)N * (size_t)K * sizeof(float);
        const size_t padded = (bytes + alignment - 1) / alignment * alignment;
        B_T = (float *)aligned_alloc(alignment, padded);
    }

    if (B_T == NULL)
    {
        fprintf(stderr, "Error: Memory allocation failed for transposed matrix in matrix_multiply.\n");
        return;
    }

    transpose_matrix(B, B_T, K, N);

    matrix_multiply_simd(A, B_T, C, M, N, K, scale);

    free(B_T);
}
0 commit comments