Skip to content

Commit f325011

Browse files
committed
unify handling of tensor_product(dual)
1 parent a36f4d7 commit f325011

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

NDTensors/src/lib/GradedAxes/src/fusion.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ function tensor_product(
1919
return foldl(tensor_product, (a1, a2, a3, a_rest...))
2020
end
2121

22-
function tensor_product(::AbstractUnitRange, ::AbstractUnitRange)
23-
return error("Not implemented yet.")
22+
flip_dual(r::AbstractUnitRange) = r
23+
flip_dual(r::UnitRangeDual) = flip(r)
24+
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
25+
return tensor_product(flip_dual(a1), flip_dual(a2))
2426
end
2527

2628
function tensor_product(a1::Base.OneTo, a2::Base.OneTo)
@@ -39,20 +41,6 @@ function tensor_product(::OneToOne, ::OneToOne)
3941
return OneToOne()
4042
end
4143

42-
# Handle dual. Always return a non-dual GradedUnitRange.
43-
function tensor_product(a1::AbstractUnitRange, a2::UnitRangeDual)
44-
return tensor_product(a1, flip(a2))
45-
end
46-
47-
function tensor_product(a1::UnitRangeDual, a2::AbstractUnitRange)
48-
return tensor_product(flip(a1), a2)
49-
end
50-
51-
# TBD change convention to tensor(dual, dual) -> dual?
52-
function tensor_product(a1::UnitRangeDual, a2::UnitRangeDual)
53-
return tensor_product(flip(a1), flip(a2))
54-
end
55-
5644
function fuse_labels(x, y)
5745
return error(
5846
"`fuse_labels` not implemented for object of type `$(typeof(x))` and `$(typeof(y))`."

0 commit comments

Comments
 (0)