Skip to content

Commit cc9eace

Browse files
authored
Improve the YaoToEinsum visualization (#567)
* annotate output as gray, fix xor utf8 rendering * fix tests
1 parent 99fa2a5 commit cc9eace

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

lib/YaoToEinsum/ext/LuxorExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ function YaoToEinsum.viznet(tn::TensorNetwork; scale=100, filename=nothing, dual
7575
graph = SimpleGraph(length(label_coos) + length(tensor_coos)) # the first batch of vertices are for labels
7676
label2idx = Dict(zip(labels, 1:length(labels)))
7777
ixs = OMEinsum.getixsv(tn.code)
78+
iy = OMEinsum.getiyv(tn.code)
7879
for (i, ix) in enumerate(ixs)
7980
for label in ix
8081
add_edge!(graph, label2idx[label], i + length(labels))
@@ -83,7 +84,7 @@ function YaoToEinsum.viznet(tn::TensorNetwork; scale=100, filename=nothing, dual
8384

8485
locs = [Tuple(coo .* scale) for coo in vcat(label_coos, tensor_coos)] # flip x-y axis
8586
vertex_shapes = [i <= length(labels) ? :circle : :circle for i in 1:length(locs)] # box for variables
86-
vertex_colors = [i <= length(labels) ? (labels[i] > 0 ? "hotpink" : "lawngreen") : "transparent" for i in 1:length(locs)]
87+
vertex_colors = [i <= length(labels) ? (labels[i] > 0 ? (labels[i] iy ? "lightgray" : "hotpink") : (labels[i] iy ? "lightgray" : "lawngreen")) : "transparent" for i in 1:length(locs)]
8788
vertex_stroke_colors = [i <= length(labels) ? "transparent" : "black" for i in 1:length(locs)]
8889
vertex_sizes = [i <= length(labels) ? 8 : (length(ixs[i-length(labels)]) == 1 ? 6 : node_size) for i in 1:length(locs)]
8990
texts = [i <= length(labels) ? string(labels[i]) : special_tensor_detection(tn.tensors[i-length(labels)]) for i in 1:length(locs)]
@@ -139,7 +140,7 @@ function special_tensor_detection(t::AbstractArray)
139140
elseif isdelta(t)
140141
return "δ"
141142
elseif isxor(t)
142-
return ""
143+
return "+"
143144
end
144145
return ""
145146
end

lib/YaoToEinsum/test/LuxorExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ using YaoToEinsum, Test, YaoToEinsum.OMEinsum, LuxorGraphPlot, YaoToEinsum.YaoBl
5353
# Test XOR tensors (2x2x...x2 tensors with XOR pattern)
5454
xor_tensor = zeros(2, 2, 2)
5555
xor_tensor[1, 1, 1] = xor_tensor[1,2,2] = xor_tensor[2,1,2] = xor_tensor[2,2,1] = 1 # even number of 2's (0)
56-
@test ext.special_tensor_detection(xor_tensor) == ""
56+
@test ext.special_tensor_detection(xor_tensor) == "+"
5757

5858
# Test unrecognized tensors
5959
@test ext.special_tensor_detection([1, 2, 3]) == ""
@@ -84,4 +84,4 @@ end
8484

8585
tn = yao2einsum(c; initial_state, observable=put(3, 2=>X), mode=PauliBasisMode())
8686
@test viznet(tn) isa LuxorGraphPlot.Drawing
87-
end
87+
end

0 commit comments

Comments
 (0)