@@ -15,12 +15,43 @@ import CUDA
1515parent_array_type (:: Type{<:CUDA.CuArray{T, N, B} where {N}} ) where {T, B} =
1616 CUDA. CuArray{T, N, B} where {N}
1717
18+ # allow on-device use of lazy broadcast objects
19+ parent_array_type (
20+ :: Type{<:CUDA.CuDeviceArray{T, N, A} where {N}} ,
21+ ) where {T, A} = CUDA. CuDeviceArray{T, N, A} where {N}
22+
1823# Ensure that both parent array types have the same memory buffer type.
1924promote_parent_array_type (
2025 :: Type{CUDA.CuArray{T1, N, B} where {N}} ,
2126 :: Type{CUDA.CuArray{T2, N, B} where {N}} ,
2227) where {T1, T2, B} = CUDA. CuArray{promote_type (T1, T2), N, B} where {N}
2328
29+ # allow on-device use of lazy broadcast objects
30+ promote_parent_array_type (
31+ :: Type{CUDA.CuDeviceArray{T1, N, B} where {N}} ,
32+ :: Type{CUDA.CuDeviceArray{T2, N, B} where {N}} ,
33+ ) where {T1, T2, B} = CUDA. CuDeviceArray{promote_type (T1, T2), N, B} where {N}
34+
35+ # allow on-device use of lazy broadcast objects with different type params
36+ promote_parent_array_type (
37+ :: Type{CUDA.CuDeviceArray{T1, N, B1} where {N}} ,
38+ :: Type{CUDA.CuDeviceArray{T2, N, B2} where {N}} ,
39+ ) where {T1, T2, B1, B2} =
40+ CUDA. CuDeviceArray{promote_type (T1, T2), N, B} where {N, B}
41+
42+ # allow on-device use of lazy broadcast objects with different type params
43+ promote_parent_array_type (
44+ :: Type{CUDA.CuDeviceArray{T1}} ,
45+ :: Type{CUDA.CuDeviceArray{T2, N, B2} where {N}} ,
46+ ) where {T1, T2, B2} =
47+ CUDA. CuDeviceArray{promote_type (T1, T2), N, B} where {N, B}
48+
49+ promote_parent_array_type (
50+ :: Type{CUDA.CuDeviceArray{T1, N, B1} where {N}} ,
51+ :: Type{CUDA.CuDeviceArray{T2} where {N}} ,
52+ ) where {T1, T2, B1} =
53+ CUDA. CuDeviceArray{promote_type (T1, T2), N, B} where {N, B}
54+
2455# Make `similar` accept our special `UnionAll` parent array type for CuArray.
2556Base. similar (
2657 :: Type{CUDA.CuArray{T, N′, B} where {N′}} ,
0 commit comments