@@ -114,56 +114,11 @@ end
114
114
115
115
# # high-level functionality
116
116
117
- function transpose_blocks! (
118
- state, odata:: AbstractArray{T} , idata, :: Val{SHMEM} , :: Val{TDIM} , :: Val{BLOCK_ROWS} , :: Val{NROW}
119
- ) where {T, SHMEM, TDIM, BLOCK_ROWS, NROW}
120
-
121
- tile = @LocalMemory (state, T, SHMEM)
122
- bidx_x = blockidx_x (state) - 1
123
- bidx_y = blockidx_y (state) - 1
124
- tidx_x = threadidx_x (state) - 1
125
- tidx_y = threadidx_y (state) - 1
126
-
127
- x = bidx_x * TDIM + tidx_x + 1
128
- y = bidx_y * TDIM + tidx_y + 1
129
- dims = size (idata)
130
-
131
- (x <= dims[2 ] && (y + (BLOCK_ROWS * 3 )) <= dims[1 ]) || return
132
-
133
- for j = 0 : 3
134
- j0 = j * BLOCK_ROWS
135
- @inbounds tile[tidx_x + 1 , tidx_y + j0 + 1 ] = idata[y + j0, x]
136
- end
137
-
138
- synchronize_threads (state)
139
- for j = 0 : 3
140
- j0 = j * BLOCK_ROWS
141
- @inbounds odata[x, y + j0] = tile[tidx_x + 1 , tidx_y + j0 + 1 ]
142
- end
143
-
144
- return
145
- end
146
-
147
117
function LinearAlgebra. transpose! (At:: AbstractGPUArray{T, 2} , A:: AbstractGPUArray{T, 2} ) where T
148
- if size (A, 1 ) == size (A, 2 ) && all (x-> x % 32 == 0 , size (A))
149
- outsize = size (At)
150
- TDIM = 32 ; BLOCK_ROWS = 8
151
- nrows = TDIM ÷ BLOCK_ROWS
152
- shmemdim = (TDIM, (TDIM + 1 ))
153
- static_params = map (x-> Val (x), (shmemdim, TDIM, BLOCK_ROWS, nrows))
154
- args = (At, A, static_params... )
155
-
156
- griddim = ceil .(Int, size (A) ./ (TDIM, TDIM))
157
- blockdim = (TDIM, BLOCK_ROWS)
158
- # optimized version for 32x & square dimensions
159
- gpu_call (transpose_blocks!, At, args, (griddim, blockdim))
160
- else
161
- # simple fallback
162
- gpu_call (At, (At, A)) do state, At, A
163
- idx = @cartesianidx A state
164
- @inbounds At[idx[2 ], idx[1 ]] = A[idx[1 ], idx[2 ]]
165
- return
166
- end
118
+ gpu_call (At, (At, A)) do state, At, A
119
+ idx = @cartesianidx A state
120
+ @inbounds At[idx[2 ], idx[1 ]] = A[idx[1 ], idx[2 ]]
121
+ return
167
122
end
168
123
At
169
124
end
0 commit comments