Skip to content

Commit 3433bc9

Browse files
authored
Fix an issue with unsorted deleteat (#186)
* Fix an issue with unsorted `deleteat` * Add testcase
1 parent e5651a0 commit 3433bc9

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

ext/TensorOperationscuTENSORExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,19 +278,19 @@ function plan_trace(@nospecialize(A::AbstractArray), Ainds::ModeType,
278278
# TODO: check if this can be avoided, available in caller
279279
# TODO: cuTENSOR will allocate sizes and strides anyways, could use that here
280280
p, q = TO.trace_indices(tuple(Ainds...), tuple(Cinds...))
281-
281+
qsorted = TT.sort(q[2])
282282
# add strides of cindA2 to strides of cindA1 -> selects diagonal
283283
stA = strides(A)
284284
for (i, j) in zip(q...)
285285
stA = Base.setindex(stA, stA[i] + stA[j], i)
286286
end
287-
szA = TT.deleteat(size(A), q[2])
288-
stA′ = TT.deleteat(stA, q[2])
287+
szA = TT.deleteat(size(A), qsorted)
288+
stA′ = TT.deleteat(stA, qsorted)
289289

290290
descA = CuTensorDescriptor(A; size=szA, strides=stA′)
291291
descC = CuTensorDescriptor(C)
292292

293-
modeA = collect(Cint, deleteat!(Ainds, q[2]))
293+
modeA = collect(Cint, deleteat!(Ainds, qsorted))
294294
modeC = collect(Cint, Cinds)
295295

296296
actual_compute_type = if compute_type === nothing

test/cutensor.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,4 +336,14 @@ if cuTENSOR.has_cutensor()
336336
@test copy(C) Ccopy
337337
end
338338
end
339+
340+
@testset "Issues" verbose = true begin
341+
@testset "Issue PR #186" begin
342+
# https://github.com/Jutho/TensorOperations.jl/pull/186
343+
A = randn(Float32, (5, 5, 5, 5))
344+
Atr = @tensor A[a, b, b, a]
345+
Atr2 = @cutensor A[a, b, b, a]
346+
@test Atr Atr2
347+
end
348+
end
339349
end

0 commit comments

Comments
 (0)