|
1 | 1 | # TypeParameterAccessors definitions |
2 | | -using Dagger: Dagger, DArray |
| 2 | +using Dagger: Dagger, Blocks, DArray |
| 3 | +using NDTensors: NDTensors |
3 | 4 | using NDTensors.TypeParameterAccessors: |
4 | | - TypeParameterAccessors, Position, default_type_parameters, parameter, position |
| 5 | + TypeParameterAccessors, Position, default_type_parameters, parameter, position, set_type_parameters |
5 | 6 |
|
6 | 7 | blocktype(darray::DArray) = blocktype(typeof(darray)) |
7 | 8 | blocktype(darrayT::Type{<:DArray}) = parameter(darrayT, position(darrayT, blocktype)) |
|
13 | 14 | concattype(darray::DArray) = concattype(typeof(darray)) |
14 | 15 | concattype(darrayT::Type{<:DArray}) = parameter(darrayT, position(darrayT, concattype)) |
15 | 16 |
|
16 | | -function TypeParameterAccessors.position(::Type{<:DArray}, ::typeof(concattype)) |
| 17 | +function TypeParameterAccessors.position(::Type{<:DArray}, ::typeof(concattype)) |
17 | 18 | return Position(4) |
18 | 19 | end |
19 | 20 |
|
20 | 21 | ## TODO use autoblock |
21 | 22 | function TypeParameterAccessors.default_type_parameters(::Type{<:DArray}) |
22 | 23 | return (default_type_parameters(AbstractArray)..., Blocks{2}, typeof(cat)) |
23 | 24 | end |
| 25 | + |
| 26 | +## TODO need to make this work. Need to specify |
| 27 | +function NDTensors.set_ndims(type::Type{<:DArray}, param) |
| 28 | + return set_type_parameters(type, (ndims, blocktype), (param, Blocks{param})) |
| 29 | +end |
0 commit comments