|
1 | | -@kwdef mutable struct OrthogonalLinearFormProblem{State, LinearFormNetwork} |
| 1 | +@kwdef mutable struct FittingProblem{State, OverlapNetwork} |
2 | 2 | state::State |
3 | | - linearformnetwork::LinearFormNetwork |
| 3 | + overlapnetwork::OverlapNetwork |
4 | 4 | squared_scalar::Number = 0 |
5 | 5 | end |
6 | 6 |
|
7 | | -squared_scalar(O::OrthogonalLinearFormProblem) = O.squared_scalar |
8 | | -state(O::OrthogonalLinearFormProblem) = O.state |
9 | | -linearformnetwork(O::OrthogonalLinearFormProblem) = O.linearformnetwork |
| 7 | +squared_scalar(F::FittingProblem) = F.squared_scalar |
| 8 | +state(F::FittingProblem) = F.state |
| 9 | +overlapnetwork(F::FittingProblem) = F.overlapnetwork |
10 | 10 |
|
11 | | -function set(O::OrthogonalLinearFormProblem; state = state(O), linearformnetwork = linearformnetwork(O), squared_scalar = squared_scalar(O)) |
12 | | - return OrthogonalLinearFormProblem(; state, linearformnetwork, squared_scalar) |
| 11 | +function set(F::FittingProblem; state = state(F), overlapnetwork = overlapnetwork(F), squared_scalar = squared_scalar(F)) |
| 12 | + return FittingProblem(; state, linearformnetwork, squared_scalar) |
13 | 13 | end |
14 | 14 |
|
15 | | -function updater!(O::OrthogonalLinearFormProblem, local_tensor, region; outputlevel, kws...) |
16 | | - O.squared_scalar, local_tensor = linearform_updater |
17 | | - |
18 | | -function maximize_linearformnetwork_sq(linearformnetwork, init_state; nsweeps, nsites=1, outputlevel = 0, update_kwargs = (;), inserter_kwargs = (;), kws...) |
19 | | - init_prob = OrthogonalLinearFormProblem(; state = copy(init_state), linearformnetwork = linearformnetwork) |
| 15 | +function fit_tensornetwork(tn::AbstractITensorNetwork, init_state::AbstractITensorNetwork, vertex_partitioning) |
| 16 | + overlap_bpc = BeliefPropagationCache(inner_network(tn, init_state), vertex_partitioning) |
| 17 | + init_prob = FittingProblem(; state = copy(init_state), overlapnetwork = overlap_bpc) |
20 | 18 | common_sweep_kwargs = (; nsites, outputlevel, updater_kwargs, inserter_kwargs) |
21 | 19 | kwargs_array = [(; common_sweep_kwargs..., sweep = s) for s in 1:nsweeps] |
22 | 20 | sweep_iter = sweep_iterator(init_prob, kwargs_array) |
|
0 commit comments