Skip to content

Commit f3de4de

Browse files
authored
Support Subarray{<:Adjoint{<:GPUMatrix}} (#108)
* Support Subarray{<:Adjoint{<:GPUMatrix}} * Restore Strided
1 parent cea4178 commit f3de4de

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,17 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false})
112112
end
113113
# COV_EXCL_STOP
114114

115+
const SupportedROCMatrix{T} = Union{AnyROCMatrix{T}, SubArray{T, 2, <:AnyROCMatrix{T}}}
116+
115117
function MatrixAlgebraKit._project_hermitian_offdiag!(
116-
Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
118+
Au::SupportedROCMatrix, Al::SupportedROCMatrix, Bu::SupportedROCMatrix, Bl::SupportedROCMatrix, ::Val{anti}
117119
) where {anti}
118120
thread_dim = 512
119121
block_dim = cld(size(Au, 2), thread_dim)
120122
@roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
121123
return nothing
122124
end
123-
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
125+
function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedROCMatrix, B::SupportedROCMatrix, ::Val{anti}) where {anti}
124126
thread_dim = 512
125127
block_dim = cld(size(A, 1), thread_dim)
126128
@roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,17 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false})
136136
end
137137
# COV_EXCL_STOP
138138

139+
const SupportedCuMatrix{T} = Union{AnyCuMatrix{T}, SubArray{T, 2, <:AnyCuMatrix{T}}}
140+
139141
function MatrixAlgebraKit._project_hermitian_offdiag!(
140-
Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
142+
Au::SupportedCuMatrix, Al::SupportedCuMatrix, Bu::SupportedCuMatrix, Bl::SupportedCuMatrix, ::Val{anti}
141143
) where {anti}
142144
thread_dim = 512
143145
block_dim = cld(size(Au, 2), thread_dim)
144146
@cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
145147
return nothing
146148
end
147-
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
149+
function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedCuMatrix, B::SupportedCuMatrix, ::Val{anti}) where {anti}
148150
thread_dim = 512
149151
block_dim = cld(size(A, 1), thread_dim)
150152
@cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))

0 commit comments

Comments
 (0)