1- const ROCTensorMap{T,S, N₁,N₂} = TensorMap{T,S, N₁,N₂, ROCVector{T,AMDGPU. Mem. HIPBuffer}}
1+ const ROCTensorMap{T, S, N₁, N₂} = TensorMap{T, S, N₁, N₂, ROCVector{T, AMDGPU. Mem. HIPBuffer}}
22const ROCTensor{T, S, N} = ROCTensorMap{T, S, N, 0 }
33
4- const AdjointROCTensorMap{T,S, N₁,N₂} = AdjointTensorMap{T,S, N₁,N₂,ROCTensorMap{T,S, N₁,N₂}}
4+ const AdjointROCTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, ROCTensorMap{T, S, N₁, N₂}}
55
66function TensorKit. tensormaptype (S:: Type{<:IndexSpace} , N₁, N₂, TorA:: Type{<:StridedROCArray} )
77 if TorA <: ROCArray
8- return TensorMap{eltype (TorA),S, N₁,N₂,ROCVector{eltype (TorA), AMDGPU. Mem. HIPBuffer}}
8+ return TensorMap{eltype (TorA), S, N₁, N₂, ROCVector{eltype (TorA), AMDGPU. Mem. HIPBuffer}}
99 else
1010 throw (ArgumentError (" argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:ROCVector{<:Number}`" ))
1111 end
1212end
1313
1414function ROCTensorMap {T} (:: UndefInitializer , V:: TensorMapSpace{S, N₁, N₂} ) where {T, S, N₁, N₂}
15- return ROCTensorMap {T,S, N₁,N₂} (undef, V)
15+ return ROCTensorMap {T, S, N₁, N₂} (undef, V)
1616end
1717
18- function ROCTensorMap {T} (:: UndefInitializer , codomain:: TensorSpace{S} ,
19- domain:: TensorSpace{S} ) where {T,S}
18+ function ROCTensorMap {T} (
19+ :: UndefInitializer , codomain:: TensorSpace{S} ,
20+ domain:: TensorSpace{S}
21+ ) where {T, S}
2022 return ROCTensorMap {T} (undef, codomain ← domain)
2123end
22- function ROCTensor {T} (:: UndefInitializer , V:: TensorSpace{S} ) where {T,S}
24+ function ROCTensor {T} (:: UndefInitializer , V:: TensorSpace{S} ) where {T, S}
2325 return ROCTensorMap {T} (undef, V ← one (V))
2426end
2527# constructor starting from block data
@@ -42,8 +44,10 @@ Construct a `ROCTensorMap` by explicitly specifying its block data.
4244Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref)
4345using the syntax `codomain ← domain` or `domain → codomain`.
4446"""
45- function ROCTensorMap (data:: AbstractDict{<:Sector,<:ROCArray} ,
46- V:: TensorMapSpace{S,N₁,N₂} ) where {S,N₁,N₂}
47+ function ROCTensorMap (
48+ data:: AbstractDict{<:Sector, <:ROCArray} ,
49+ V:: TensorMapSpace{S, N₁, N₂}
50+ ) where {S, N₁, N₂}
4751 T = eltype (valtype (data))
4852 t = ROCTensorMap {T} (undef, V)
4953 for (c, b) in blocks (t)
@@ -59,12 +63,16 @@ function ROCTensorMap(data::AbstractDict{<:Sector,<:ROCArray},
5963 end
6064 return t
6165end
62- function ROCTensorMap {T} (data:: DenseVector{T} , codomain:: TensorSpace{S} ,
63- domain:: TensorSpace{S} ) where {T,S}
66+ function ROCTensorMap {T} (
67+ data:: DenseVector{T} , codomain:: TensorSpace{S} ,
68+ domain:: TensorSpace{S}
69+ ) where {T, S}
6470 return ROCTensorMap (data, codomain ← domain)
6571end
66- function ROCTensorMap (data:: AbstractDict{<:Sector,<:ROCMatrix} , codom:: TensorSpace{S} ,
67- dom:: TensorSpace{S} ) where {S}
72+ function ROCTensorMap (
73+ data:: AbstractDict{<:Sector, <:ROCMatrix} , codom:: TensorSpace{S} ,
74+ dom:: TensorSpace{S}
75+ ) where {S}
6876 return ROCTensorMap (data, codom ← dom)
6977end
7078
7482
7583for (fname, felt) in ((:zeros , :zero ), (:ones , :one ))
7684 @eval begin
77- function AMDGPU. $fname (codomain:: TensorSpace{S} ,
78- domain:: TensorSpace{S} = one (codomain)) where {S<: IndexSpace }
85+ function AMDGPU. $fname (
86+ codomain:: TensorSpace{S} ,
87+ domain:: TensorSpace{S} = one (codomain)
88+ ) where {S <: IndexSpace }
7989 return AMDGPU.$ fname (codomain ← domain)
8090 end
81- function AMDGPU. $fname (:: Type{T} , codomain:: TensorSpace{S} ,
82- domain:: TensorSpace{S} = one (codomain)) where {T,S<: IndexSpace }
91+ function AMDGPU. $fname (
92+ :: Type{T} , codomain:: TensorSpace{S} ,
93+ domain:: TensorSpace{S} = one (codomain)
94+ ) where {T, S <: IndexSpace }
8395 return AMDGPU.$ fname (T, codomain ← domain)
8496 end
8597 AMDGPU.$ fname (V:: TensorMapSpace ) = AMDGPU.$ fname (Float64, V)
@@ -95,17 +107,23 @@ for randfun in (:rocrand, :rocrandn)
95107 randfun! = Symbol (randfun, :! )
96108 @eval begin
97109 # converting `codomain` and `domain` into `HomSpace`
98- function $randfun (codomain:: TensorSpace{S} ,
99- domain:: TensorSpace{S} ) where {S<: IndexSpace }
110+ function $randfun (
111+ codomain:: TensorSpace{S} ,
112+ domain:: TensorSpace{S}
113+ ) where {S <: IndexSpace }
100114 return $ randfun (codomain ← domain)
101115 end
102- function $randfun (:: Type{T} , codomain:: TensorSpace{S} ,
103- domain:: TensorSpace{S} ) where {T,S<: IndexSpace }
116+ function $randfun (
117+ :: Type{T} , codomain:: TensorSpace{S} ,
118+ domain:: TensorSpace{S}
119+ ) where {T, S <: IndexSpace }
104120 return $ randfun (T, codomain ← domain)
105121 end
106- function $randfun (rng:: Random.AbstractRNG , :: Type{T} ,
107- codomain:: TensorSpace{S} ,
108- domain:: TensorSpace{S} ) where {T,S<: IndexSpace }
122+ function $randfun (
123+ rng:: Random.AbstractRNG , :: Type{T} ,
124+ codomain:: TensorSpace{S} ,
125+ domain:: TensorSpace{S}
126+ ) where {T, S <: IndexSpace }
109127 return $ randfun (rng, T, codomain ← domain)
110128 end
111129
@@ -114,8 +132,10 @@ for randfun in (:rocrand, :rocrandn)
114132 function $randfun (:: Type{T} , codomain:: TensorSpace ) where {T}
115133 return $ randfun (T, codomain ← one (codomain))
116134 end
117- function $randfun (rng:: Random.AbstractRNG , :: Type{T} ,
118- codomain:: TensorSpace ) where {T}
135+ function $randfun (
136+ rng:: Random.AbstractRNG , :: Type{T} ,
137+ codomain:: TensorSpace
138+ ) where {T}
119139 return $ randfun (rng, T, codomain ← one (domain))
120140 end
121141
@@ -131,8 +151,10 @@ for randfun in (:rocrand, :rocrandn)
131151 end
132152
133153 # implementation
134- function $randfun (rng:: Random.AbstractRNG , :: Type{T} ,
135- V:: TensorMapSpace ) where {T}
154+ function $randfun (
155+ rng:: Random.AbstractRNG , :: Type{T} ,
156+ V:: TensorMapSpace
157+ ) where {T}
136158 t = ROCTensorMap {T} (undef, V)
137159 $ randfun! (rng, t)
138160 return t
142164
143165# converters
144166# ----------
145- function Base. convert (:: Type{ROCTensorMap} , d:: Dict{Symbol,Any} )
167+ function Base. convert (:: Type{ROCTensorMap} , d:: Dict{Symbol, Any} )
146168 try
147169 codomain = eval (Meta. parse (d[:codomain ]))
148170 domain = eval (Meta. parse (d[:domain ]))
@@ -151,8 +173,10 @@ function Base.convert(::Type{ROCTensorMap}, d::Dict{Symbol,Any})
151173 catch e # sector unknown in TensorKit.jl; user-defined, hopefully accessible in Main
152174 codomain = Base. eval (Main, Meta. parse (d[:codomain ]))
153175 domain = Base. eval (Main, Meta. parse (d[:domain ]))
154- data = SectorDict (Base. eval (Main, Meta. parse (c)) => ROCArray (b)
155- for (c, b) in d[:data ])
176+ data = SectorDict (
177+ Base. eval (Main, Meta. parse (c)) => ROCArray (b)
178+ for (c, b) in d[:data ]
179+ )
156180 return ROCTensorMap (data, codomain, domain)
157181 end
158182end
@@ -164,22 +188,24 @@ end
164188# Scalar implementation
165189# -----------------------
166190function TensorKit. scalar (t:: ROCTensorMap )
167-
191+
168192 # TODO : should scalar only work if N₁ == N₂ == 0?
169193 return @allowscalar dim (codomain (t)) == dim (domain (t)) == 1 ?
170- first (blocks (t))[2 ][1 , 1 ] : throw (DimensionMismatch ())
194+ first (blocks (t))[2 ][1 , 1 ] : throw (DimensionMismatch ())
171195end
172196
173197TensorKit. scalartype (A:: StridedROCArray{T} ) where {T} = T
174198vi_scalartype (:: Type{<:ROCTensorMap{T}} ) where {T} = T
175199vi_scalartype (:: Type{<:ROCArray{T}} ) where {T} = T
176200
177- function TensorKit. similarstoragetype (TT:: Type{<:ROCTensorMap{TTT,S, N₁,N₂}} , :: Type{T} ) where {TTT,T,S, N₁,N₂}
201+ function TensorKit. similarstoragetype (TT:: Type{<:ROCTensorMap{TTT, S, N₁, N₂}} , :: Type{T} ) where {TTT, T, S, N₁, N₂}
178202 return ROCVector{T, AMDGPU. Mem. HIPBuffer}
179203end
180204
181- function Base. convert (TT:: Type{ROCTensorMap{T,S,N₁,N₂}} ,
182- t:: AbstractTensorMap{<:Any,S,N₁,N₂} ) where {T,S,N₁,N₂}
205+ function Base. convert (
206+ TT:: Type{ROCTensorMap{T, S, N₁, N₂}} ,
207+ t:: AbstractTensorMap{<:Any, S, N₁, N₂}
208+ ) where {T, S, N₁, N₂}
183209 if typeof (t) === TT
184210 return t
185211 else
@@ -204,12 +230,16 @@ function Base.copy!(tdst::ROCTensorMap, tsrc::TensorKit.AdjointTensorMap)
204230 return tdst
205231end
206232
207- function Base. promote_rule (:: Type{<:TT₁} ,
208- :: Type{<:TT₂} ) where {S,N₁,N₂, TTT₁, TTT₂,
209- TT₁<: ROCTensorMap{TTT₁,S,N₁,N₂} ,
210- TT₂<: ROCTensorMap{TTT₂,S,N₁,N₂} }
233+ function Base. promote_rule (
234+ :: Type{<:TT₁} ,
235+ :: Type{<:TT₂}
236+ ) where {
237+ S, N₁, N₂, TTT₁, TTT₂,
238+ TT₁ <: ROCTensorMap{TTT₁, S, N₁, N₂} ,
239+ TT₂ <: ROCTensorMap{TTT₂, S, N₁, N₂} ,
240+ }
211241 T = TensorKit. VectorInterface. promote_add (TTT₁, TTT₂)
212- return ROCTensorMap{T,S, N₁,N₂}
242+ return ROCTensorMap{T, S, N₁, N₂}
213243end
214244
215245function LinearAlgebra. isposdef (t:: ROCTensorMap )
@@ -218,7 +248,7 @@ function LinearAlgebra.isposdef(t::ROCTensorMap)
218248 InnerProductStyle (spacetype (t)) === EuclideanInnerProduct () || return false
219249 for (c, b) in blocks (t)
220250 # do our own hermitian check
221- isherm = TensorKit. MatrixAlgebraKit. ishermitian (b; atol= eps (real (eltype (b))), rtol= eps (real (eltype (b))))
251+ isherm = TensorKit. MatrixAlgebraKit. ishermitian (b; atol = eps (real (eltype (b))), rtol = eps (real (eltype (b))))
222252 isherm || return false
223253 isposdef (Hermitian (b)) || return false
224254 end
0 commit comments