Skip to content

Commit 32cc2d3

Browse files
authored
Merge branch 'master' into ksh/cutensor_bump
2 parents b41d359 + 5d35cf0 commit 32cc2d3

File tree

6 files changed

+57
-37
lines changed

6 files changed

+57
-37
lines changed

.github/workflows/FormatCheck.yml

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
name: FormatCheck
1+
name: 'Format'
22

33
on:
4-
push:
5-
branches:
6-
- 'main'
7-
- 'master'
8-
tags: '*'
9-
pull_request:
10-
branches:
11-
- 'main'
12-
- 'master'
4+
pull_request_target:
5+
paths: ['**/*.jl']
6+
types: [opened, synchronize, reopened, ready_for_review]
7+
8+
permissions:
9+
contents: read
10+
actions: write
11+
pull-requests: write
1312

1413
jobs:
1514
formatcheck:
16-
name: "Format Check"
17-
uses: "QuantumKitHub/QuantumKitHubActions/.github/workflows/FormatCheck.yml@main"
15+
uses: "QuantumKitHub/QuantumKitHubActions/.github/workflows/FormatCheck.yml@main"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
33
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
4-
version = "5.4.0"
4+
version = "5.3.2"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

docs/src/man/indexnotation.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,23 @@ however different strategies to modify this order.
243243
because they share an index label which is next in the `order` list, all other indices
244244
with shared label among them will be contracted, irrespective of their order.
245245

246+
!!! warning "Combining order specifications"
247+
248+
Note that it is currently not possible to combine the NCON style convention of specifying
249+
indices with the use of parentheses. If both are used at the same time, the parentheses
250+
take precedence and the NCON style will be ignored. Any remaining contraction orders
251+
will be evaluated in the default left to right order. For example, in the following
252+
contractions, we have `E1 = A * ((B * D) * C)`, but `E2 = A * ((B * C) * D)`. This is
253+
true even when the parentheses are compatible with the NCON contraction order, as is
254+
the case here.
255+
256+
```julia
257+
@tensor E1[-1 -2 -3; -4] := A[-1 -2 -3; 4 5] * B[4; 1] * C[5; 2] * D[1 2; -4]
258+
@tensor E2[-1 -2 -3; -4] := A[-1 -2 -3; 4 5] * (B[4; 1] * C[5; 2] * D[1 2; -4])
259+
```
260+
261+
Additionally, combining the `order = (...)` keyword with parentheses is currently not supported.
262+
246263
In the case of more complex tensor networks, the optimal contraction order cannot always
247264
easily be guessed or determined on plain sight. It is then useful to be able to optimize the
248265
contraction order automatically, given a model for the complexity of contracting the

ext/TensorOperationscuTENSORExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ end
147147
#-------------------------------------------------------------------------------------------
148148
# Allocator
149149
#-------------------------------------------------------------------------------------------
150-
function CUDAAllocator()
150+
function TO.CUDAAllocator()
151151
Mout = CUDA.UnifiedMemory
152152
Min = CUDA.default_memory
153153
Mtemp = CUDA.default_memory

src/indexnotation/instantiators.jl

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,15 @@ function instantiate_generaltensor(
122122
β = βsym
123123
end
124124
if alloc (NewTensor, TemporaryTensor)
125-
TC = gensym("T_" * string(dst))
125+
TCsym = gensym("T_" * string(dst))
126126
istemporary = Val(alloc === TemporaryTensor)
127-
if scaltype === nothing
128-
TCval = α === One() ? instantiate_scalartype(src) :
129-
instantiate_scalartype(Expr(:call, :*, α, src))
130-
else
131-
TCval = scaltype
132-
end
133-
push!(out.args, Expr(:(=), TC, TCval))
127+
TCval = @something(
128+
scaltype, instantiate_scalartype=== One() ? src : Expr(:call, :*, α, src))
129+
)
130+
push!(out.args, Expr(:(=), TCsym, TCval))
134131
push!(
135132
out.args,
136-
Expr(:(=), dst, :(tensoralloc_add($TC, $src, $p, $conj, $istemporary)))
133+
Expr(:(=), dst, :(tensoralloc_add($TCsym, $src, $p, $conj, $istemporary)))
137134
)
138135
end
139136

@@ -167,9 +164,9 @@ function instantiate_linearcombination(
167164
)
168165
out = Expr(:block)
169166
if alloc (NewTensor, TemporaryTensor)
170-
if scaltype === nothing
171-
scaltype = instantiate_scalartype(ex)
172-
end
167+
scaltype = @something(
168+
scaltype, instantiate_scalartype=== One() ? ex : Expr(:call, :*, α, ex))
169+
)
173170
push!(
174171
out.args,
175172
instantiate(dst, β, ex.args[2], α, leftind, rightind, alloc, scaltype)
@@ -275,18 +272,15 @@ function instantiate_contraction(
275272
end
276273
if alloc (NewTensor, TemporaryTensor)
277274
TCsym = gensym("T_" * string(dst))
278-
if scaltype === nothing
279-
Atype = instantiate_scalartype(A)
280-
Btype = instantiate_scalartype(B)
281-
TCval = Expr(:call, :promote_contract, Atype, Btype)
282-
if α !== One()
283-
TCval = Expr(
284-
:call, :(Base.promote_op), :*, instantiate_scalartype(α), TCval
285-
)
275+
TCval = @something(
276+
scaltype,
277+
begin
278+
TA = instantiate_scalartype(A)
279+
TB = instantiate_scalartype(B)
280+
TAB = :(promote_contract($TA, $TB))
281+
α === One() ? TAB : :(Base.promote_op(*, $(instantiate_scalartype(α)), $TAB))
286282
end
287-
else
288-
TCval = scaltype
289-
end
283+
)
290284
istemporary = Val(alloc === TemporaryTensor)
291285
initC = Expr(
292286
:block, Expr(:(=), TCsym, TCval),

test/tensor.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,4 +581,15 @@ end
581581
@test isblascontractable(pA, p)
582582
@test isblascontractable(conj(pA), p)
583583
end
584+
585+
@testset "Issue 220" begin
586+
A = rand(2, 2)
587+
B = rand(2, 2)
588+
C = rand(2, 2)
589+
D = rand(2, 2)
590+
c = 1im
591+
@tensor E[a; c] := c * (A[a b] * B[b c] + C[a b] * D[b c])
592+
@test scalartype(E) == ComplexF64
593+
@test E c * (A * B + C * D)
594+
end
584595
end

0 commit comments

Comments
 (0)