Skip to content

Commit 6991ac8

Browse files
committed
Update TypeParameterAccessors to v0.2
1 parent 0efe21e commit 6991ac8

File tree

9 files changed

+32
-25
lines changed

9 files changed

+32
-25
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ StridedViews = "0.2.2, 0.3"
9393
TBLIS = "0.2"
9494
TimerOutputs = "0.5.5"
9595
TupleTools = "1.2.0"
96-
TypeParameterAccessors = "0.1"
96+
TypeParameterAccessors = "0.2"
9797
VectorInterface = "0.4.2, 0.5"
9898
cuTENSOR = "2"
9999
julia = "1.10"

NDTensors/ext/NDTensorsAMDGPUExt/adapt.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,14 @@ using NDTensors: NDTensors, EmptyStorage, adapt_storagetype, emptytype
22
using NDTensors.AMDGPUExtensions: AMDGPUExtensions, ROCArrayAdaptor
33
using NDTensors.GPUArraysCoreExtensions: storagemode
44
using NDTensors.TypeParameterAccessors:
5-
default_type_parameter,
6-
set_type_parameter,
7-
set_type_parameters,
8-
type_parameter,
9-
type_parameters
5+
default_type_parameters, set_type_parameters, type_parameters
106
using Adapt: Adapt, adapt
117
using AMDGPU: AMDGPU, ROCArray, ROCVector
128
using Functors: fmap
139

14-
function AMDGPUExtensions.roc(xs; storagemode=default_type_parameter(ROCArray, storagemode))
10+
function AMDGPUExtensions.roc(
11+
xs; storagemode=default_type_parameters(ROCArray, storagemode)
12+
)
1513
return fmap(x -> adapt(ROCArrayAdaptor{storagemode}(), x), xs)
1614
end
1715

NDTensors/ext/NDTensorsCUDAExt/adapt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ using NDTensors: NDTensors, EmptyStorage, adapt_storagetype, emptytype
55
using NDTensors.CUDAExtensions: CUDAExtensions, CuArrayAdaptor
66
using NDTensors.GPUArraysCoreExtensions: storagemode
77
using NDTensors.TypeParameterAccessors:
8-
default_type_parameter, set_type_parameters, type_parameters
8+
default_type_parameters, set_type_parameters, type_parameters
99

10-
function CUDAExtensions.cu(xs; storagemode=default_type_parameter(CuArray, storagemode))
10+
function CUDAExtensions.cu(xs; storagemode=default_type_parameters(CuArray, storagemode))
1111
return fmap(x -> adapt(CuArrayAdaptor{storagemode}(), x), xs)
1212
end
1313

NDTensors/src/abstractarray/generic_array_constructors.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using TypeParameterAccessors:
2-
unwrap_array_type, specify_default_type_parameters, type_parameter
2+
unwrap_array_type,
3+
specify_default_type_parameters,
4+
specify_type_parameters,
5+
type_parameters
36

47
# Convert to Array, avoiding copying if possible
58
array(a::AbstractArray) = a
@@ -8,9 +11,9 @@ vector(a::AbstractVector) = a
811

912
## Warning to use these functions it is necessary to define `TypeParameterAccessors.position(::Type{<:YourArrayType}, ::typeof(ndims)))`
1013
# Implementation, catches if `ndims(arraytype) != length(dims)`.
11-
## TODO convert ndims to `type_parameter(::, typeof(ndims))`
14+
## TODO convert ndims to `type_parameters(::, typeof(ndims))`
1215
function generic_randn(arraytype::Type{<:AbstractArray}, dims...; rng=Random.default_rng())
13-
arraytype_specified = specify_type_parameter(
16+
arraytype_specified = specify_type_parameters(
1417
unwrap_array_type(arraytype), ndims, length(dims)
1518
)
1619
arraytype_specified = specify_default_type_parameters(arraytype_specified)
@@ -27,7 +30,7 @@ end
2730

2831
# Implementation, catches if `ndims(arraytype) != length(dims)`.
2932
function generic_zeros(arraytype::Type{<:AbstractArray}, dims...)
30-
arraytype_specified = specify_type_parameter(
33+
arraytype_specified = specify_type_parameters(
3134
unwrap_array_type(arraytype), ndims, length(dims)
3235
)
3336
arraytype_specified = specify_default_type_parameters(arraytype_specified)

NDTensors/src/adapt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ double_precision(x) = fmap(x -> adapt(double_precision(eltype(x)), x), x)
2727
# Used to adapt `EmptyStorage` types
2828
#
2929

30-
using TypeParameterAccessors: specify_type_parameter, specify_type_parameters
30+
using TypeParameterAccessors: specify_type_parameters
3131
function adapt_storagetype(to::Type{<:AbstractVector}, x::Type{<:TensorStorage})
32-
return set_datatype(x, specify_type_parameter(to, eltype, eltype(x)))
32+
return set_datatype(x, specify_type_parameters(to, eltype, eltype(x)))
3333
end
3434

3535
function adapt_storagetype(to::Type{<:AbstractArray}, x::Type{<:TensorStorage})
36-
return set_datatype(x, specify_type_parameter(to, (ndims, eltype), (1, eltype(x))))
36+
return set_datatype(x, specify_type_parameters(to, (ndims, eltype), (1, eltype(x))))
3737
end

NDTensors/src/dense/generic_array_constructors.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
using TypeParameterAccessors:
2-
default_type_parameter,
2+
default_type_parameters,
33
parenttype,
44
set_eltype,
55
specify_default_type_parameters,
6-
type_parameter
6+
specify_type_parameters,
7+
type_parameters
78
##TODO replace randn in ITensors with generic_randn
89
## and replace zeros with generic_zeros
910

@@ -12,7 +13,9 @@ using TypeParameterAccessors:
1213

1314
function generic_randn(StoreT::Type{<:Dense}, dims::Integer; rng=Random.default_rng())
1415
StoreT = specify_default_type_parameters(StoreT)
15-
DataT = specify_type_parameter(type_parameter(StoreT, parenttype), eltype, eltype(StoreT))
16+
DataT = specify_type_parameters(
17+
type_parameters(StoreT, parenttype), eltype, eltype(StoreT)
18+
)
1619
@assert eltype(StoreT) == eltype(DataT)
1720

1821
data = generic_randn(DataT, dims; rng=rng)
@@ -22,7 +25,9 @@ end
2225

2326
function generic_zeros(StoreT::Type{<:Dense}, dims::Integer)
2427
StoreT = specify_default_type_parameters(StoreT)
25-
DataT = specify_type_parameter(type_parameter(StoreT, parenttype), eltype, eltype(StoreT))
28+
DataT = specify_type_parameters(
29+
type_parameters(StoreT, parenttype), eltype, eltype(StoreT)
30+
)
2631
@assert eltype(StoreT) == eltype(DataT)
2732

2833
data = generic_zeros(DataT, dims)

NDTensors/src/lib/Expose/src/exposed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using TypeParameterAccessors:
2-
TypeParameterAccessors, unwrap_array_type, parameter, parenttype, type_parameter
2+
TypeParameterAccessors, unwrap_array_type, parenttype, type_parameters
33
struct Exposed{Unwrapped,Object}
44
object::Object
55
end
@@ -9,7 +9,7 @@ expose(object) = Exposed{unwrap_array_type(object),typeof(object)}(object)
99
unexpose(E::Exposed) = E.object
1010

1111
## TODO remove TypeParameterAccessors when SetParameters is removed
12-
TypeParameterAccessors.parenttype(type::Type{<:Exposed}) = parameter(type, parenttype)
12+
TypeParameterAccessors.parenttype(type::Type{<:Exposed}) = type_parameters(type, parenttype)
1313
function TypeParameterAccessors.position(::Type{<:Exposed}, ::typeof(parenttype))
1414
return TypeParameterAccessors.Position(1)
1515
end

NDTensors/src/lib/GPUArraysCoreExtensions/src/gpuarrayscore.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
using ..Expose: Exposed, unexpose
2-
using TypeParameterAccessors: TypeParameterAccessors, type_parameter, set_type_parameter
2+
using TypeParameterAccessors: TypeParameterAccessors, type_parameters, set_type_parameters
33

44
function storagemode(object)
55
return storagemode(typeof(object))
66
end
77
function storagemode(type::Type)
8-
return type_parameter(type, storagemode)
8+
return type_parameters(type, storagemode)
99
end
1010

1111
function set_storagemode(type::Type, param)
12-
return set_type_parameter(type, storagemode, param)
12+
return set_type_parameters(type, storagemode, param)
1313
end
1414

1515
function cpu end

NDTensors/test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2020
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
2121
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
2222
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23+
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2324
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2425

2526
[compat]

0 commit comments

Comments
 (0)