Skip to content

Commit 837e151

Browse files
committed
fix names in array_cast
1 parent ee06e3f commit 837e151

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

src/fusiontensor/array_cast.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,34 @@ function FusionTensor(
1515
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
1616
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
1717
)
18-
return cast_from_array(array, codomain_legs, domain_legs)
18+
return to_fusiontensor(array, codomain_legs, domain_legs)
1919
end
2020

2121
#### cast from symmetric to array
2222
function BlockSparseArrays.BlockSparseArray(ft::FusionTensor)
23-
return cast_to_array(ft)
23+
return to_array(ft)
2424
end
2525

2626
# ================================= Low level interface ==================================
27-
function cast_from_array(
27+
function to_fusiontensor(
2828
array::AbstractArray,
2929
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
3030
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
3131
)
3232
bounds = block_dimensions.((codomain_legs..., domain_legs...))
3333
blockarray = BlockedArray(array, bounds...)
34-
return cast_from_array(blockarray, codomain_legs, domain_legs)
34+
return to_fusiontensor(blockarray, codomain_legs, domain_legs)
3535
end
3636

37-
function cast_from_array(
37+
get_tol(a::AbstractArray) = get_tol(real(eltype(a)))
38+
get_tol(T::Type{<:Integer}) = get_tol(Float64)
39+
get_tol(T::Type{<:Real}) = 10 * eps(T)
40+
41+
function to_fusiontensor(
3842
blockarray::AbstractBlockArray,
3943
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
4044
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}};
41-
tol::Float64=1e-12,
45+
tol::Real=get_tol(blockarray),
4246
)
4347
# input validation
4448
if length(codomain_legs) + length(domain_legs) != ndims(blockarray) # compile time
@@ -48,20 +52,26 @@ function cast_from_array(
4852
throw(DomainError("legs dimensions are incompatible with array"))
4953
end
5054

51-
ft = unsafe_cast_from_array(blockarray, codomain_legs, domain_legs)
55+
ft = to_fusiontensor_no_checknorm(blockarray, codomain_legs, domain_legs)
5256

5357
# if blockarray is not G-invariant, norm(ft) < norm(blockarray)
54-
if abs(norm(ft) - norm(blockarray)) > tol
58+
checknorm(ft, blockarray, tol)
59+
return ft
60+
end
61+
62+
function checknorm(ft::FusionTensor, a::AbstractArray, tol::Real)
63+
n0 = norm(a)
64+
if abs(norm(ft) - n0) > tol * n0
5565
throw(
5666
InexactError(
57-
:FusionTensor, typeof(blockarray), typeof(codomain_legs), typeof(domain_legs)
67+
:FusionTensor, typeof(a), typeof(codomain_axes(ft)), typeof(domain_axes(ft))
5868
),
5969
)
6070
end
6171
return ft
6272
end
6373

64-
function unsafe_cast_from_array(
74+
function to_fusiontensor_no_checknorm(
6575
blockarray::AbstractBlockArray,
6676
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
6777
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
@@ -74,7 +84,7 @@ function unsafe_cast_from_array(
7484
return ft
7585
end
7686

77-
function cast_to_array(ft::FusionTensor)
87+
function to_array(ft::FusionTensor)
7888
bounds = block_dimensions.((codomain_axes(ft)..., domain_axes(ft)...))
7989
bsa = BlockSparseArray{eltype(ft)}(blockedrange.(bounds))
8090
for (f1, f2) in keys(trees_block_mapping(ft))

src/fusiontensor/base_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function Base.:/(ft::FusionTensor, x::Number)
5353
)
5454
end
5555

56-
Base.Array(ft::FusionTensor) = Array(cast_to_array(ft))
56+
Base.Array(ft::FusionTensor) = Array(to_array(ft))
5757

5858
# adjoint is costless: dual axes, swap codomain and domain, take data_matrix adjoint.
5959
# data_matrix coeff are not modified (beyond complex conjugation)

0 commit comments

Comments
 (0)