Skip to content

Commit aab2e4a

Browse files
committed
Updates to get contract working
1 parent b6da28b commit aab2e4a

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

NDTensors/ext/NDTensorsDaggerExt/set_types.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# TypeParameterAccessors definitions
2-
using Dagger: Dagger, DArray
2+
using Dagger: Dagger, Blocks, DArray
3+
using NDTensors: NDTensors
34
using NDTensors.TypeParameterAccessors:
4-
TypeParameterAccessors, Position, default_type_parameters, parameter, position
5+
TypeParameterAccessors, Position, default_type_parameters, parameter, position, set_type_parameters
56

67
blocktype(darray::DArray) = blocktype(typeof(darray))
78
blocktype(darrayT::Type{<:DArray}) = parameter(darrayT, position(darrayT, blocktype))
@@ -13,11 +14,16 @@ end
1314
concattype(darray::DArray) = concattype(typeof(darray))
1415
concattype(darrayT::Type{<:DArray}) = parameter(darrayT, position(darrayT, concattype))
1516

16-
function TypeParameterAccessors.position(::Type{<:DArray}, ::typeof(concattype))
17+
function TypeParameterAccessors.position(::Type{<:DArray}, ::typeof(concattype))
1718
return Position(4)
1819
end
1920

2021
## TODO use autoblock
2122
function TypeParameterAccessors.default_type_parameters(::Type{<:DArray})
2223
return (default_type_parameters(AbstractArray)..., Blocks{2}, typeof(cat))
2324
end
25+
26+
## TODO need to make this work. Need to specify
27+
function NDTensors.set_ndims(type::Type{<:DArray}, param)
28+
return set_type_parameters(type, (ndims, blocktype), (param, Blocks{param}))
29+
end

NDTensors/src/tensorstorage/similar.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ function similar(storagetype::Type{<:TensorStorage}, dims::Dims)
4545
# TODO: Don't convert to an `AbstractVector` with `prod`, once we support
4646
# more general data types.
4747
# return setdata(storagetype, NDTensors.similar(datatype(storagetype), dims))
48+
## TODO use expose here to work for DArray
4849
return setdata(storagetype, NDTensors.similar(datatype(storagetype), prod(dims)))
4950
end
5051

0 commit comments

Comments
 (0)