Skip to content
13 changes: 10 additions & 3 deletions src/tensor_product.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# This files defines an interface for the tensor product of two axes
# https://en.wikipedia.org/wiki/Tensor_product

# ================================== misc ================================================
is_offset_axis(a::AbstractUnitRange) = isone(first(a))

function require_one_based_axis(a::AbstractUnitRange)
return !is_offset_axis(a) && throw(ArgumentError("Range must be one-based"))
end

# ============================== tensor product ==========================================
⊗() = tensor_product()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could also be an alias:

Suggested change
() = tensor_product()
const = tensor_product

Then you could define methods for either, instead of both

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not compatible with being either tensor_product or fusion_product depending on the input.

⊗(a) = tensor_product(a)

Expand All @@ -16,9 +24,8 @@ tensor_product(a1, a2, as...) = tensor_product(tensor_product(a1, a2), as...)

# default
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
!(isone(first(a1)) && isone(first(a2))) &&
throw(ArgumentError("Ranges must be one-based"))
return Base.OneTo(prod(length.((a1, a2))))
require_one_based_axis(a1) || require_one_based_axis(a2)
return Base.OneTo(length(a1) * length(a2))
end

tensor_product(::OneToOne, ::OneToOne) = OneToOne()
Loading