1 file changed: +15 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -6410,11 +6410,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
64106410
64116411 threadgroup_barrier (mem_flags::mem_threadgroup);
64126412
6413- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
64146413 if (sgitg == 0 ) {
6415- for (int i = 0 ; i < n_rows; i++) {
6416- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6417- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
6414+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6415+ device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
6416+ device float4 * D4 = (device float4 *) D;
6417+
6418+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
6419+ threadgroup float4 * C4 = (threadgroup float4 *) C;
6420+
6421+ int i = 0 ;
6422+ for (; i < n_rows/4 ; i++) {
6423+ *(D4 + i) = *(C4 + i);
6424+ }
6425+
6426+ i *= 4 ;
6427+ for (; i < n_rows; i++) {
6428+ *(D + i) = *(C + i);
64186429 }
64196430 }
64206431 }
You can’t perform that action at this time.
0 commit comments