Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NDTensors/ext/NDTensorsDaggerExt/NDTensorsDaggerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
module NDTensorsDaggerExt
include("set_types.jl")
include("similar.jl")
end
34 changes: 34 additions & 0 deletions NDTensors/ext/NDTensorsDaggerExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# TypeParameterAccessors definitions
using Dagger: Dagger, Blocks, DArray
using NDTensors: NDTensors
using NDTensors.TypeParameterAccessors:
TypeParameterAccessors,
Position,
default_type_parameters,
parameter,
position,
set_type_parameters

blocktype(darray::DArray) = blocktype(typeof(darray))
blocktype(darrayT::Type{<:DArray}) = parameter(darrayT, position(darrayT, blocktype))

function TypeParameterAccessors.position(::Type{<:DArray}, ::typeof(blocktype))
return Position(3)
end

concattype(darray::DArray) = concattype(typeof(darray))
concattype(darrayT::Type{<:DArray}) = parameter(darrayT, position(darrayT, concattype))

function TypeParameterAccessors.position(::Type{<:DArray}, ::typeof(concattype))
return Position(4)
end

## TODO use autoblock
function TypeParameterAccessors.default_type_parameters(::Type{<:DArray})
return (default_type_parameters(AbstractArray)..., Blocks{2}, typeof(cat))
end

## TODO need to make this work. Need to specify
function NDTensors.set_ndims(type::Type{<:DArray}, param)
return set_type_parameters(type, (ndims, blocktype), (param, Blocks{param}))
end
12 changes: 12 additions & 0 deletions NDTensors/ext/NDTensorsDaggerExt/similar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using NDTensors: NDTensors
using NDTensors.Expose: Exposed, unexpose
using Dagger: DArray

function NDTensors.similar(E::Exposed{<:DArray})
A = unexpose(E)
return Base.similar(A)
end

function NDTensors.similar(E::Exposed{<:DArray}, eltype::Type)
return Base.similar(unexpose(E), eltype)
end
5 changes: 5 additions & 0 deletions NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ end
# NDTensors.similar
similar(array::AbstractArray, dims::Tuple) = NDTensors.similar(typeof(array), dims)

using NDTensors.Expose: Exposed, unexpose
function similar(E::Exposed, eltype::Type)
return similar(unexpose(E), eltype)
end
# Use the `size` to determine the dimensions
# NDTensors.similar
function similar(array::AbstractArray, eltype::Type)
Expand All @@ -95,6 +99,7 @@ end

# Use the `size` to determine the dimensions
# NDTensors.similar
similar(E::Exposed) = similar(unexpose(E))
similar(array::AbstractArray) = NDTensors.similar(typeof(array), size(array))

## similartype
Expand Down
5 changes: 3 additions & 2 deletions NDTensors/src/tensorstorage/similar.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# NDTensors.similar
similar(storage::TensorStorage) = setdata(storage, NDTensors.similar(data(storage)))
similar(storage::TensorStorage) = setdata(storage, NDTensors.similar(expose(data(storage))))

# NDTensors.similar
function similar(storage::TensorStorage, eltype::Type)
return setdata(storage, NDTensors.similar(data(storage), eltype))
return setdata(storage, NDTensors.similar(expose(data(storage)), eltype))
end

# NDTensors.similar
Expand Down Expand Up @@ -45,6 +45,7 @@ function similar(storagetype::Type{<:TensorStorage}, dims::Dims)
# TODO: Don't convert to an `AbstractVector` with `prod`, once we support
# more general data types.
# return setdata(storagetype, NDTensors.similar(datatype(storagetype), dims))
## TODO use expose here to work for DArray
return setdata(storagetype, NDTensors.similar(datatype(storagetype), prod(dims)))
end

Expand Down
Loading