Skip to content

Commit 113ad3c

Browse files
committed
Small PR to start getting Dagger to work in the Dense framework
1 parent 9e3efd2 commit 113ad3c

File tree

5 files changed

+45
-2
lines changed

5 files changed

+45
-2
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
module NDTensorsDaggerExt
2+
include("set_types.jl")
3+
include("similar.jl")
4+
end
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# TypeParameterAccessors definitions
2+
using Dagger: Dagger, DArray
3+
using NDTensors.TypeParameterAccessors:
4+
TypeParameterAccessors, Position, default_type_parameters, parameter, position
5+
6+
blocktype(darray::DArray) = blocktype(typeof(darray))
7+
blocktype(darrayT::Type{<:DArray}) = parameter(darrayT, position(darrayT, blocktype))
8+
9+
function TypeParameterAccessors.position(::Type{<:DArray}, ::typeof(blocktype))
10+
return Position(3)
11+
end
12+
13+
concattype(darray::DArray) = concattype(typeof(darray))
14+
concattype(darrayT::Type{<:DArray}) = parameter(darrayT, position(darrayT, concattype))
15+
16+
function TypeParameterAccessors.position(::Type{<:DArray}, ::typeof(concattype))
17+
return Position(4)
18+
end
19+
20+
## TODO use autoblock
21+
function TypeParameterAccessors.default_type_parameters(::Type{<:DArray})
22+
return (default_type_parameters(AbstractArray)..., Blocks{2}, typeof(cat))
23+
end
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using NDTensors: NDTensors
2+
using NDTensors.Expose: Exposed, unexpose
3+
using Dagger: DArray
4+
5+
function NDTensors.similar(E::Exposed{<:DArray})
6+
A = unexpose(E)
7+
return Base.similar(A)
8+
end
9+
10+
function similar(E::Exposed{<:DArray}, eltype::Type)
11+
return Base.similar(unexpose(E), eltype)
12+
end

NDTensors/src/abstractarray/similar.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ end
8787
# NDTensors.similar
8888
similar(array::AbstractArray, dims::Tuple) = NDTensors.similar(typeof(array), dims)
8989

90+
function similar(E::Exposed, eltype::Type)
91+
return similar(unexpose(E), eltype)
92+
end
9093
# Use the `size` to determine the dimensions
9194
# NDTensors.similar
9295
function similar(array::AbstractArray, eltype::Type)
@@ -95,6 +98,7 @@ end
9598

9699
# Use the `size` to determine the dimensions
97100
# NDTensors.similar
101+
similar(E::Exposed) = similar(unexpose(E))
98102
similar(array::AbstractArray) = NDTensors.similar(typeof(array), size(array))
99103

100104
## similartype

NDTensors/src/tensorstorage/similar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# NDTensors.similar
2-
similar(storage::TensorStorage) = setdata(storage, NDTensors.similar(data(storage)))
2+
similar(storage::TensorStorage) = setdata(storage, NDTensors.similar(expose(data(storage))))
33

44
# NDTensors.similar
55
function similar(storage::TensorStorage, eltype::Type)
6-
return setdata(storage, NDTensors.similar(data(storage), eltype))
6+
return setdata(storage, NDTensors.similar(expose(data(storage)), eltype))
77
end
88

99
# NDTensors.similar

0 commit comments

Comments
 (0)