Skip to content

Commit 7be6c1e

Browse files
committed
Add SNAPHU MCF initializer based on or-tools
1 parent 1101ab9 commit 7be6c1e

File tree

6 files changed

+387
-9
lines changed

6 files changed

+387
-9
lines changed

cxx/isce3/unwrap/snaphu/snaphu.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,9 +517,8 @@ int UnwrapTile(infileT *infiles, outfileT *outfiles, paramT *params,
517517

518518
}else if(params->initmethod==MCFINIT){
519519

520-
fflush(NULL);
521-
throw isce3::except::InvalidArgument(ISCE_SRCINFO(),
522-
"MCF initialization not implemented");
520+
/* use minimum cost flow (MCF) algorithm */
521+
MCFInitFlows(wrappedphase,&flows,mstcosts,nrow,ncol);
523522

524523
}else{
525524
fflush(NULL);
@@ -725,7 +724,6 @@ int UnwrapTile(infileT *infiles, outfileT *outfiles, paramT *params,
725724
/* flip the sign of the unwrapped phase array if it was flipped initially, */
726725
FlipPhaseArraySign(unwrappedphase,params,nrow,ncol);
727726

728-
729727
/* write the unwrapped output */
730728
fprintf(sp1,"Writing output to file %s\n",outfiles->outfile);
731729
WriteOutputFile(mag,unwrappedphase,outfiles->outfile,outfiles,

cxx/isce3/unwrap/snaphu/snaphu.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
#include <isce3/except/Error.h>
2121

22-
namespace isce3::unwrap {
23-
2422
/**********************/
2523
/* defined constants */
2624
/**********************/
@@ -129,6 +127,7 @@ namespace isce3::unwrap {
129127
#define NARMS 8 /* number of arms for Despeckle() */
130128
#define ARMLEN 5 /* length of arms for Despeckle() */
131129
#define KEDGE 5 /* length of edge detection window */
130+
#define ARCUBOUND 200 /* capacities for MCF solver */
132131
#define MSTINIT 1 /* initialization method */
133132
#define MCFINIT 2 /* initialization method */
134133
#define BIGGESTDZRHOMAX 10000.0
@@ -419,6 +418,8 @@ namespace isce3::unwrap {
419418
"\n"
420419

421420

421+
namespace isce3::unwrap {
422+
422423
/********************/
423424
/* type definitions */
424425
/********************/
@@ -823,6 +824,8 @@ totalcostT EvaluateTotalCost(Array2D<typename CostTag::Cost>& costs, Array2D<sho
823824
int MSTInitFlows(Array2D<float>& wrappedphase, Array2D<short>* flowsptr,
824825
Array2D<short>& mstcosts, long nrow, long ncol,
825826
Array2D<nodeT>* nodes, nodeT *ground, long maxflow);
827+
int MCFInitFlows(Array2D<float>& wrappedphase, Array2D<short>* flowsptr, Array2D<short>& mstcosts,
828+
long nrow, long ncol);
826829

827830

828831
/* functions in snaphu_cost.c */

cxx/isce3/unwrap/snaphu/snaphu_solver.cpp

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
#include <cstdlib>
1212
#include <cstring>
13+
#include <limits>
1314

1415
#include <isce3/except/Error.h>
16+
#include <isce3/unwrap/ortools/min_cost_flow.h>
1517

1618
#include "snaphu.h"
1719

@@ -2350,7 +2352,7 @@ int InitNetwork(Array2D<short>& flows, long *ngroundarcsptr, long *ncycleptr,
23502352
long i;
23512353

23522354
/* get and initialize memory for nodes */
2353-
if(ground!=NULL && nodesptr->size()){
2355+
if(ground!=NULL && !nodesptr->size()){
23542356
*nodesptr = Array2D<nodeT>(nrow-1, ncol-1);
23552357
InitNodeNums(nrow-1,ncol-1,*nodesptr,ground);
23562358
}
@@ -3673,6 +3675,173 @@ signed char ClipFlow(Array2D<signed char>& residue, Array2D<short>& flows,
36733675
}
36743676

36753677

3678+
/* function: MCFInitFlows()
3679+
* ------------------------
3680+
* Initializes the flow on the network using a minimum cost flow
3681+
* algorithm.
3682+
*/
3683+
int MCFInitFlows(Array2D<float>& wrappedphase, Array2D<short>* flowsptr,
3684+
Array2D<short>& mstcosts, long nrow, long ncol){
3685+
3686+
/* number of rows & cols of nodes in the residue network */
3687+
const auto m=nrow-1;
3688+
const auto n=ncol-1;
3689+
3690+
/* calculate phase residues (integer numbers of cycles) */
3691+
auto residue=Array2D<signed char>(m,n);
3692+
CycleResidue(wrappedphase,residue,nrow,ncol);
3693+
3694+
/* total number of nodes and directed arcs in the network */
3695+
const auto nnodes=m*n+1;
3696+
const auto narcs=2*((m+1)*n+(n+1)*m);
3697+
3698+
/* the solver uses 32-bit integers for node & arc indices */
3699+
/* check for possible overflow */
3700+
using operations_research::NodeIndex;
3701+
using operations_research::ArcIndex;
3702+
if(nnodes>std::numeric_limits<NodeIndex>::max()){
3703+
throw isce3::except::RuntimeError(ISCE_SRCINFO(),
3704+
"Number of MCF network nodes exceeds maximum representable value");
3705+
}
3706+
if(narcs>std::numeric_limits<ArcIndex>::max()){
3707+
throw isce3::except::RuntimeError(ISCE_SRCINFO(),
3708+
"Number of MCF network arcs exceeds maximum representable value");
3709+
}
3710+
3711+
/* begin building the network topology and setting up the MCF problem */
3712+
using Network=operations_research::SimpleMinCostFlow;
3713+
auto network=Network(nnodes,narcs);
3714+
3715+
/* assigns a positive integer label to each grid node */
3716+
/* grid node indices begin at 1 (index 0 is used for the ground node) */
3717+
auto GetNodeIndex=[=](long i, long j)->NodeIndex{
3718+
return 1+i*n+j;
3719+
};
3720+
constexpr NodeIndex ground=0;
3721+
3722+
/* adds a pair of forward & reverse arcs to the network connecting two nodes */
3723+
/* sister arcs have equal cost and capacity */
3724+
using operations_research::CostValue;
3725+
using operations_research::FlowQuantity;
3726+
auto AddSisterArcs=[&](NodeIndex node1, NodeIndex node2, CostValue cost){
3727+
constexpr static auto capacity=static_cast<FlowQuantity>(ARCUBOUND);
3728+
network.AddArcWithCapacityAndUnitCost(node2,node1,capacity,cost);
3729+
network.AddArcWithCapacityAndUnitCost(node1,node2,capacity,cost);
3730+
};
3731+
3732+
/* break down arc costs into row (horizontal) & col (vertical) cost arrays */
3733+
const auto rowcosts=mstcosts.topLeftCorner(m,n+1);
3734+
const auto colcosts=mstcosts.bottomLeftCorner(m+1,n);
3735+
3736+
/* arcs are assigned sequential indices (starting from 0) in the order that
3737+
they're added to the network */
3738+
/* we rely on this fact later on when extracting flows from the network */
3739+
3740+
/* begin adding horizontal arcs to the network */
3741+
for(long i=0;i<m;++i){
3742+
/* add a pair of arcs between the left border node and the ground node */
3743+
{
3744+
const auto node=GetNodeIndex(i,0);
3745+
const auto cost=static_cast<CostValue>(rowcosts(i,0));
3746+
AddSisterArcs(ground,node,cost);
3747+
}
3748+
3749+
/* add a pair of horizontal arcs between each adjacent grid node */
3750+
for(long j=0;j<n-1;++j){
3751+
const auto node1=GetNodeIndex(i,j);
3752+
const auto node2=GetNodeIndex(i,j+1);
3753+
const auto cost=static_cast<CostValue>(rowcosts(i,j+1));
3754+
AddSisterArcs(node1,node2,cost);
3755+
}
3756+
3757+
/* add a pair of arcs between the right border node and the ground node */
3758+
{
3759+
const auto node=GetNodeIndex(i,n-1);
3760+
const auto cost=static_cast<CostValue>(rowcosts(i,n));
3761+
AddSisterArcs(node,ground,cost);
3762+
}
3763+
}
3764+
3765+
/* begin adding vertical arcs to the network */
3766+
/* add a pair of arcs between each top border node and the ground node */
3767+
for(long j=0;j<n;++j){
3768+
const auto node=GetNodeIndex(0,j);
3769+
const auto cost=static_cast<CostValue>(colcosts(0,j));
3770+
AddSisterArcs(ground,node,cost);
3771+
}
3772+
/* add a pair of vertical arcs between each adjacent grid node */
3773+
for(long i=0;i<m-1;++i){
3774+
for(long j=0;j<n;++j){
3775+
const auto node1=GetNodeIndex(i,j);
3776+
const auto node2=GetNodeIndex(i+1,j);
3777+
const auto cost=static_cast<CostValue>(colcosts(i+1,j));
3778+
AddSisterArcs(node1,node2,cost);
3779+
}
3780+
}
3781+
/* add a pair of arcs between each bottom border node and the ground node */
3782+
for(long j=0;j<n;++j){
3783+
const auto node=GetNodeIndex(m-1,j);
3784+
const auto cost=static_cast<CostValue>(colcosts(m,j));
3785+
AddSisterArcs(node,ground,cost);
3786+
}
3787+
3788+
/* add node supplies to the network */
3789+
FlowQuantity totalsupply=0;
3790+
for(long i=0;i<m;++i){
3791+
for(long j=0;j<n;++j){
3792+
auto node=GetNodeIndex(i,j);
3793+
auto supply=static_cast<FlowQuantity>(residue(i,j));
3794+
network.SetNodeSupply(node,supply);
3795+
totalsupply+=supply;
3796+
}
3797+
}
3798+
3799+
/* add enough demand to the ground node to balance the network */
3800+
network.SetNodeSupply(ground,-totalsupply);
3801+
3802+
/* run the solver to produce L1-optimal flows */
3803+
if(network.Solve() != Network::OPTIMAL){
3804+
throw isce3::except::RuntimeError(ISCE_SRCINFO(),
3805+
"MCF initialization failed");
3806+
}
3807+
3808+
*flowsptr=MakeRowColArray2D<short>(nrow,ncol);
3809+
3810+
/* break down arc flows into row (horizontal) & col (vertical) flow arrays */
3811+
auto rowflows=flowsptr->topLeftCorner(m,n+1);
3812+
auto colflows=flowsptr->bottomLeftCorner(m+1,n);
3813+
3814+
/* extract arc flows from the network */
3815+
/* the easiest way to do this is in the exact order in which the arcs were
3816+
added to the network (relying implicitly on the sequential ordering of arc
3817+
indices) */
3818+
3819+
/* extract horizontal flows from the network */
3820+
ArcIndex arcidx=0;
3821+
for(long i=0;i<m;++i){
3822+
for(long j=0;j<n+1;++j){
3823+
/* Compute eastward-minus-westward net flow */
3824+
const auto x1=network.Flow(arcidx++);
3825+
const auto x2=network.Flow(arcidx++);
3826+
rowflows(i,j)=x2-x1;
3827+
}
3828+
}
3829+
3830+
/* extract vertical flows from the network */
3831+
for(long i=0;i<m+1;++i){
3832+
for(long j=0;j<n;++j){
3833+
/* Compute southward-minus-northward net flow */
3834+
const auto x1=network.Flow(arcidx++);
3835+
const auto x2=network.Flow(arcidx++);
3836+
colflows(i,j)=x2-x1;
3837+
}
3838+
}
3839+
3840+
/* done */
3841+
return(0);
3842+
}
3843+
3844+
36763845
#define INSTANTIATE_TEMPLATES(T) \
36773846
template long TreeSolve(Array2D<nodeT>&, Array2D<nodesuppT>&, nodeT*, \
36783847
nodeT*, Array1D<candidateT>*, \

python/packages/isce3/unwrap/snaphu.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,9 @@ def from_flat_file(
765765
]
766766
CostParams.__doc__ = """SNAPHU cost mode configuration parameters"""
767767

768+
InitMethod = Literal["mst", "mcf"]
769+
InitMethod.__doc__ = """SNAPHU initialization method"""
770+
768771

769772
def unwrap(
770773
unw: isce3.io.gdal.Raster,
@@ -774,6 +777,7 @@ def unwrap(
774777
nlooks: float,
775778
cost: str = "smooth",
776779
cost_params: Optional[CostParams] = None,
780+
init_method: InitMethod = "mcf",
777781
pwr: Optional[isce3.io.gdal.Raster] = None,
778782
mask: Optional[isce3.io.gdal.Raster] = None,
779783
unwest: Optional[isce3.io.gdal.Raster] = None,
@@ -866,6 +870,10 @@ def unwrap(
866870
Configuration parameters for the specified cost mode. This argument is
867871
required for "topo" mode and optional for all other modes. If None, the
868872
default configuration parameters are used. (default: None)
873+
init_method: {"mst", "mcf"}, optional
874+
Algorithm used for initialization of unwrapped phase gradients.
875+
Supported algorithms include Minimum Spanning Tree ("mst") and Minimum
876+
Cost Flow ("mcf"). (default: "mcf")
869877
pwr : isce3.io.gdal.Raster or None, optional
870878
Average intensity of the two SLCs, in linear units (not dB). Only used
871879
in "topo" cost mode. If None, interferogram magnitude is used as
@@ -971,8 +979,14 @@ def cost_string():
971979

972980
configstr += f"STATCOSTMODE {cost_string()}\n"
973981

974-
# XXX Currently, only "MST" initialization method is supported.
975-
configstr += "INITMETHOD MST\n"
982+
def init_string():
983+
if init_method == "mst":
984+
return "MST"
985+
if init_method == "mcf":
986+
return "MCF"
987+
raise ValueError(f"invalid init method '{init_method}'")
988+
989+
configstr += f"INITMETHOD {init_string()}\n"
976990

977991
# Check cost mode-specific configuration params.
978992
if cost == "topo":

tests/cxx/isce3/Sources.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ signal/signal.cpp
7979
signal/signal_utils.cpp
8080
unwrap/icu/icu.cpp
8181
unwrap/phass/phass.cpp
82+
unwrap/snaphu/mcf.cpp
8283
)
8384

8485
#This is a temporary fix - since GDAL does not support

0 commit comments

Comments
 (0)