Commit 563950b

Commit message: update
Parent: b32bdfb

File tree: 2 files changed (+17 −10 lines)

src/map.jl

Lines changed: 9 additions & 5 deletions

```diff
@@ -53,15 +53,19 @@ $(TYPEDSIGNATURES)
 Returns the largest log-probability and the most probable configuration.
 """
 function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector}
-    ixs = OMEinsum.getixsv(tn.code)
-    unity_labels = ixs[tn.unity_tensors_idx]
-    indices = [findfirst(==([l]), unity_labels) for l in get_vars(tn)]
-    @assert !any(isnothing, indices) "To get the most probable configuration, the unity tensor labels must include all variables"
     vars = get_vars(tn)
+    tensor_indices = check_queryvars(tn, [[v] for v in vars])
     tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
     logp, grads = cost_and_gradient(tn.code, tensors)
     # use Array to convert CuArray to CPU arrays
-    return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[tn.unity_tensors_idx[indices[k]]]) - 1, 1:length(vars))
+    return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[tensor_indices[k]]) - 1, 1:length(vars))
+end
+
+# Check whether the query variables are covered by the unity tensor labels; if so, return the indices of the corresponding unity tensors.
+function check_queryvars(tn::TensorNetworkModel, queryvars::AbstractVector{Vector{Int}})
+    ixs = OMEinsum.getixsv(tn.code)
+    indices = [findfirst(==(l), ixs[tn.unity_tensors_idx]) for l in queryvars]
+    @assert !any(isnothing, indices) "To get the most probable configuration, the unity tensor labels must include all query variables. Query variables: $queryvars, unity tensor labels: $(ixs[tn.unity_tensors_idx])"
+    return tn.unity_tensors_idx[indices]
 end
 
 """
```

src/mar.jl

Lines changed: 8 additions & 5 deletions

```diff
@@ -136,8 +136,10 @@ tensor network.
 
 ### Arguments
 - `tn`: The [`TensorNetworkModel`](@ref) to query.
-- `usecuda`: Specifies whether to use CUDA for tensor contraction.
-- `rescale`: Specifies whether to rescale the tensors during contraction.
+
+### Keyword Arguments
+- `usecuda::Bool`: Specifies whether to use CUDA for tensor contraction.
+- `rescale::Bool`: Specifies whether to rescale the tensors during contraction.
 
 ### Example
 The following example is taken from [`examples/asia-network/main.jl`](https://tensorbfs.github.io/TensorInference.jl/dev/generated/asia-network/main/).
@@ -187,9 +189,10 @@ function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dic
     cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale))
     @debug "cost = $cost"
     ixs = OMEinsum.getixsv(tn.code)
+    queryvars = ixs[tn.unity_tensors_idx]
     if rescale
-        return Dict(zip(ixs[tn.unity_tensors_idx], LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.unity_tensors_idx)], :normalized_value), 1)))
+        return Dict(zip(queryvars, LinearAlgebra.normalize!.(getfield.(grads[tn.unity_tensors_idx], :normalized_value), 1)))
     else
-        return Dict(zip(ixs[tn.unity_tensors_idx], LinearAlgebra.normalize!.(grads[1:length(tn.unity_tensors_idx)], 1)))
+        return Dict(zip(queryvars, LinearAlgebra.normalize!.(grads[tn.unity_tensors_idx], 1)))
     end
-end
+end
```
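Similarly, a short sketch of calling the refactored `marginals` with the keyword arguments from the updated docstring, continuing from the `tn` constructed above. As the diff shows, the result is a `Dict` keyed by `queryvars`, i.e. the unity tensor labels; the `[1]` key below assumes the default unity tensors cover each single variable.

```julia
# Marginals come back as a Dict keyed by the unity tensor labels
# (each key is a vector of variable labels, the `queryvars` above).
mar = marginals(tn)                       # default: rescaled contraction
mar_raw = marginals(tn; rescale = false)  # same distributions without rescaling
mar[[1]]                                  # marginal of variable 1 (assuming it is queried)
```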
