Skip to content

Commit 54f2672

Browse files
authored
Make adjointtensormap generic (#148)
1 parent 9e78a03 commit 54f2672

File tree

2 files changed

+37
-25
lines changed

2 files changed

+37
-25
lines changed

Changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Features that are planned to be implemented before the release of v1.0.0, in no
66
- [ ] Make `TrivialTensorMap` and `TensorMap` be the same
77
- [ ] Simplify `TensorMap` type to hide `rowr` and `colr`
88
- [ ] Change block order in `rowr` / `colr` to speed up particular contractions
9-
- [ ] Make `AdjointTensorMap` generic
9+
- [x] Make `AdjointTensorMap` generic
1010
- [ ] Rewrite planar operations in order to be AD-compatible
1111
- [x] Fix rrules for fermionic tensors
1212
- [ ] Fix GPU support

src/tensors/adjoint.jl

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,83 @@
11
# AdjointTensorMap: lazy adjoint
22
#==========================================================#
33
"""
4-
struct AdjointTensorMap{T, S, N₁, N₂, ...} <: AbstractTensorMap{T, S, N₁, N₂}
4+
struct AdjointTensorMap{T, S, N₁, N₂, TT<:AbstractTensorMap} <: AbstractTensorMap{T, S, N₁, N₂}
55
66
Specific subtype of [`AbstractTensorMap`](@ref) that is a lazy wrapper for representing the
7-
adjoint of an instance of [`TensorMap`](@ref).
7+
adjoint of an instance of [`AbstractTensorMap`](@ref).
88
"""
9-
struct AdjointTensorMap{T,S,N₁,N₂,I,A,F₁,F₂} <:
9+
struct AdjointTensorMap{T,S,N₁,N₂,TT<:AbstractTensorMap{T,S,N₂,N₁}} <:
1010
AbstractTensorMap{T,S,N₁,N₂}
11-
parent::TensorMap{T,S,N₂,N₁,I,A,F₂,F₁}
11+
parent::TT
1212
end
1313

1414
#! format: off
15-
const AdjointTrivialTensorMap{T,S,N₁,N₂,A<:DenseMatrix} =
16-
AdjointTensorMap{T,S,N₁,N₂,Trivial,A,Nothing,Nothing}
15+
const AdjointTrivialTensorMap{T,S,N₁,N₂,TT<:TrivialTensorMap{T,S,N₂,N₁}} =
16+
AdjointTensorMap{T,S,N₁,N₂,TT}
1717
#! format: on
1818

1919
# Constructor: construct from taking adjoint of a tensor
20-
Base.adjoint(t::TensorMap) = AdjointTensorMap(t)
21-
Base.adjoint(t::AdjointTensorMap) = t.parent
20+
Base.adjoint(t::AbstractTensorMap) = AdjointTensorMap(t)
21+
Base.adjoint(t::AdjointTensorMap) = parent(t)
22+
23+
Base.parent(t::AdjointTensorMap) = t.parent
24+
parenttype(::Type{<:AdjointTensorMap{T,S,N₁,N₂,TT}}) where {T,S,N₁,N₂,TT} = TT
2225

2326
function Base.similar(t::AdjointTensorMap, ::Type{TorA},
2427
P::TensorMapSpace) where {TorA<:MatOrNumber}
2528
return similar(t', TorA, P)
2629
end
2730

2831
# Properties
29-
codomain(t::AdjointTensorMap) = domain(t.parent)
30-
domain(t::AdjointTensorMap) = codomain(t.parent)
32+
codomain(t::AdjointTensorMap) = domain(parent(t))
33+
domain(t::AdjointTensorMap) = codomain(parent(t))
3134

32-
blocksectors(t::AdjointTensorMap) = blocksectors(t.parent)
35+
blocksectors(t::AdjointTensorMap) = blocksectors(parent(t))
3336

34-
#! format: off
35-
storagetype(::Type{<:AdjointTrivialTensorMap{T,S,N₁,N₂,A}}) where {T,S,N₁,N₂,A<:DenseMatrix} = A
36-
storagetype(::Type{<:AdjointTensorMap{T,S,N₁,N₂,I,<:SectorDict{I,A}}}) where {T,S,N₁,N₂,I<:Sector,A<:DenseMatrix} = A
37-
#! format: on
37+
storagetype(::Type{TT}) where {TT<:AdjointTensorMap} = storagetype(parenttype(TT))
3838

39-
dim(t::AdjointTensorMap) = dim(t.parent)
39+
dim(t::AdjointTensorMap) = dim(parent(t))
4040

4141
# Indexing
4242
#----------
43-
hasblock(t::AdjointTensorMap, s::Sector) = hasblock(t.parent, s)
44-
block(t::AdjointTensorMap, s::Sector) = block(t.parent, s)'
45-
blocks(t::AdjointTensorMap) = (c => b' for (c, b) in blocks(t.parent))
43+
hasblock(t::AdjointTensorMap, s::Sector) = hasblock(parent(t), s)
44+
block(t::AdjointTensorMap, s::Sector) = block(parent(t), s)'
45+
blocks(t::AdjointTensorMap) = (c => b' for (c, b) in blocks(parent(t)))
4646

4747
fusiontrees(::AdjointTrivialTensorMap) = ((nothing, nothing),)
48-
fusiontrees(t::AdjointTensorMap) = TensorKeyIterator(t.parent.colr, t.parent.rowr)
48+
function fusiontrees(t::AdjointTensorMap{T,S,N₁,N₂,TT}) where {T,S,N₁,N₂,TT<:TensorMap}
49+
return TensorKeyIterator(parent(t).colr, parent(t).rowr)
50+
end
4951

50-
function Base.getindex(t::AdjointTensorMap{T,S,N₁,N₂,I},
52+
function Base.getindex(t::AdjointTensorMap{T,S,N₁,N₂,<:TensorMap{T,S,N₁,N₂,I}},
5153
f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {T,S,N₁,N₂,I}
5254
c = f₁.coupled
5355
@boundscheck begin
5456
c == f₂.coupled || throw(SectorMismatch())
5557
hassector(codomain(t), f₁.uncoupled) && hassector(domain(t), f₂.uncoupled)
5658
end
57-
return sreshape((StridedView(t.parent.data[c])[t.parent.rowr[c][f₂],
58-
t.parent.colr[c][f₁]])',
59+
return sreshape((StridedView(parent(t).data[c])[parent(t).rowr[c][f₂],
60+
parent(t).colr[c][f₁]])',
5961
(dims(codomain(t), f₁.uncoupled)..., dims(domain(t), f₂.uncoupled)...))
6062
end
63+
@propagate_inbounds function Base.getindex(t::AdjointTensorMap{T,S,N₁,N₂},
64+
f₁::FusionTree{I,N₁},
65+
f₂::FusionTree{I,N₂}) where {T,S,N₁,N₂,I}
66+
d_cod = dims(codomain(t), f₁.uncoupled)
67+
d_dom = dims(domain(t), f₂.uncoupled)
68+
return sreshape(sreshape(StridedView(parent(t)[f₂, f₁]), (prod(d_dom), prod(d_cod)))',
69+
(d_cod..., d_dom...))
70+
end
71+
6172
@propagate_inbounds function Base.setindex!(t::AdjointTensorMap{T,S,N₁,N₂,I}, v,
6273
f₁::FusionTree{I,N₁},
6374
f₂::FusionTree{I,N₂}) where {T,S,N₁,N₂,I}
6475
return copy!(getindex(t, f₁, f₂), v)
6576
end
6677

6778
@inline function Base.getindex(t::AdjointTrivialTensorMap)
68-
return sreshape(StridedView(t.parent.data)', (dims(codomain(t))..., dims(domain(t))...))
79+
return sreshape(StridedView(parent(t).data)',
80+
(dims(codomain(t))..., dims(domain(t))...))
6981
end
7082
@inline Base.setindex!(t::AdjointTrivialTensorMap, v) = copy!(getindex(t), v)
7183

0 commit comments

Comments
 (0)