Skip to content

Commit 11dba30

Browse files
committed
askrene: refactor MCF
Refactor MCF solver: remove structs linear_network and residual_network. Prefer passing raw data to the helper functions. Changelog-None Signed-off-by: Lagrang3 <[email protected]>
1 parent e4e79f0 commit 11dba30

File tree

1 file changed

+66
-140
lines changed

1 file changed

+66
-140
lines changed

plugins/askrene/mcf.c

Lines changed: 66 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -299,48 +299,17 @@ struct pay_parameters {
299299
double base_fee_penalty;
300300
};
301301

302-
/* Representation of the linear MCF network.
303-
* This contains the topology of the extended network (after linearization and
304-
* addition of arc duality).
305-
* This contains also the arc probability and linear fee cost, as well as
306-
* capacity; these quantities remain constant during MCF execution. */
307-
struct linear_network
308-
{
309-
struct graph *graph;
310-
311-
// probability and fee cost associated to an arc
312-
double *arc_prob_cost;
313-
s64 *arc_fee_cost;
314-
s64 *capacity;
315-
};
316-
317-
/* This is the structure that keeps track of the network properties while we
318-
* seek for a solution. */
319-
struct residual_network {
320-
/* residual capacity on arcs */
321-
s64 *cap;
322-
323-
/* some combination of prob. cost and fee cost on arcs */
324-
s64 *cost;
325-
326-
/* potential function on nodes */
327-
s64 *potential;
328-
329-
/* auxiliary data, the excess of flow on nodes */
330-
s64 *excess;
331-
};
332-
333302
/* Helper function.
334303
* Given an arc of the network (not residual) give me the flow. */
335304
static s64 get_arc_flow(
336-
const struct residual_network *network,
305+
const s64 *arc_residual_capacity,
337306
const struct graph *graph,
338307
const struct arc arc)
339308
{
340309
assert(!arc_is_dual(graph, arc));
341310
struct arc dual = arc_dual(graph, arc);
342-
assert(dual.idx < tal_count(network->cap));
343-
return network->cap[dual.idx];
311+
assert(dual.idx < tal_count(arc_residual_capacity));
312+
return arc_residual_capacity[dual.idx];
344313
}
345314

346315
/* Set *capacity to value, up to *cap_on_capacity. Reduce cap_on_capacity */
@@ -385,49 +354,6 @@ static void linearize_channel(const struct pay_parameters *params,
385354
}
386355
}
387356

388-
static struct residual_network *
389-
alloc_residual_network(const tal_t *ctx, const size_t max_num_nodes,
390-
const size_t max_num_arcs)
391-
{
392-
struct residual_network *residual_network =
393-
tal(ctx, struct residual_network);
394-
395-
residual_network->cap = tal_arrz(residual_network, s64, max_num_arcs);
396-
residual_network->cost = tal_arrz(residual_network, s64, max_num_arcs);
397-
residual_network->potential =
398-
tal_arrz(residual_network, s64, max_num_nodes);
399-
residual_network->excess =
400-
tal_arrz(residual_network, s64, max_num_nodes);
401-
402-
return residual_network;
403-
}
404-
405-
static void init_residual_network(
406-
const struct linear_network * linear_network,
407-
struct residual_network* residual_network)
408-
{
409-
const struct graph *graph = linear_network->graph;
410-
const size_t max_num_arcs = graph_max_num_arcs(graph);
411-
const size_t max_num_nodes = graph_max_num_nodes(graph);
412-
413-
for (struct arc arc = {.idx = 0}; arc.idx < max_num_arcs; ++arc.idx) {
414-
if (arc_is_dual(graph, arc) || !arc_enabled(graph, arc))
415-
continue;
416-
417-
struct arc dual = arc_dual(graph, arc);
418-
residual_network->cap[arc.idx] =
419-
linear_network->capacity[arc.idx];
420-
residual_network->cap[dual.idx] = 0;
421-
422-
residual_network->cost[arc.idx] =
423-
residual_network->cost[dual.idx] = 0;
424-
}
425-
for (u32 i = 0; i < max_num_nodes; ++i) {
426-
residual_network->potential[i] = 0;
427-
residual_network->excess[i] = 0;
428-
}
429-
}
430-
431357
static int cmp_u64(const u64 *a, const u64 *b, void *unused)
432358
{
433359
if (*a < *b)
@@ -447,9 +373,10 @@ static int cmp_double(const double *a, const double *b, void *unused)
447373
}
448374

449375
static double get_median_ratio(const tal_t *working_ctx,
450-
const struct linear_network* linear_network)
376+
const struct graph *graph,
377+
const double *arc_prob_cost,
378+
const s64 *arc_fee_cost)
451379
{
452-
const struct graph *graph = linear_network->graph;
453380
const size_t max_num_arcs = graph_max_num_arcs(graph);
454381
u64 *u64_arr = tal_arr(working_ctx, u64, max_num_arcs);
455382
double *double_arr = tal_arr(working_ctx, double, max_num_arcs);
@@ -460,8 +387,8 @@ static double get_median_ratio(const tal_t *working_ctx,
460387
if (arc_is_dual(graph, arc) || !arc_enabled(graph, arc))
461388
continue;
462389
assert(n < max_num_arcs/2);
463-
u64_arr[n] = linear_network->arc_fee_cost[arc.idx];
464-
double_arr[n] = linear_network->arc_prob_cost[arc.idx];
390+
u64_arr[n] = arc_fee_cost[arc.idx];
391+
double_arr[n] = arc_prob_cost[arc.idx];
465392
n++;
466393
}
467394
asort(u64_arr, n, cmp_u64, NULL);
@@ -475,27 +402,26 @@ static double get_median_ratio(const tal_t *working_ctx,
475402
return u64_arr[n/2] / double_arr[n/2];
476403
}
477404

478-
static void combine_cost_function(
479-
const tal_t *working_ctx,
480-
const struct linear_network* linear_network,
481-
struct residual_network *residual_network,
482-
const s8 *biases,
483-
s64 mu)
405+
static void combine_cost_function(const tal_t *working_ctx,
406+
const struct graph *graph,
407+
const double *arc_prob_cost,
408+
const s64 *arc_fee_cost, const s8 *biases,
409+
s64 mu, s64 *arc_cost)
484410
{
485411
/* probabilty and fee costs are not directly comparable!
486412
* Scale by ratio of (positive) medians. */
487-
const double k = get_median_ratio(working_ctx, linear_network);
413+
const double k =
414+
get_median_ratio(working_ctx, graph, arc_prob_cost, arc_fee_cost);
488415
const double ln_30 = log(30);
489-
const struct graph *graph = linear_network->graph;
490416
const size_t max_num_arcs = graph_max_num_arcs(graph);
491417

492418
for(struct arc arc = {.idx=0};arc.idx < max_num_arcs; ++arc.idx)
493419
{
494420
if (arc_is_dual(graph, arc) || !arc_enabled(graph, arc))
495421
continue;
496422

497-
const double pcost = linear_network->arc_prob_cost[arc.idx];
498-
const s64 fcost = linear_network->arc_fee_cost[arc.idx];
423+
const double pcost = arc_prob_cost[arc.idx];
424+
const s64 fcost = arc_fee_cost[arc.idx];
499425
double combined;
500426
u32 chanidx;
501427
int chandir;
@@ -515,13 +441,13 @@ static void combine_cost_function(
515441
* e^(-bias / (100/ln(30)))
516442
*/
517443
double bias_factor = exp(-bias / (100 / ln_30));
518-
residual_network->cost[arc.idx] = combined * bias_factor;
444+
arc_cost[arc.idx] = combined * bias_factor;
519445
} else {
520-
residual_network->cost[arc.idx] = combined;
446+
arc_cost[arc.idx] = combined;
521447
}
522448
/* and the respective dual */
523449
struct arc dual = arc_dual(graph, arc);
524-
residual_network->cost[dual.idx] = -combined;
450+
arc_cost[dual.idx] = -combined;
525451
}
526452
}
527453

@@ -578,31 +504,26 @@ struct amount_msat linear_flow_cost(const struct flow *flow,
578504
return msat_cost;
579505
}
580506

581-
/* FIXME: Instead of mapping one-to-one the indexes in the gossmap, try to
582-
* reduce the number of nodes and arcs used by taking only those that are
583-
* enabled. We might save some cpu if the work with a pruned network. */
584-
static struct linear_network *
585-
init_linear_network(const tal_t *ctx, const struct pay_parameters *params)
507+
static void init_linear_network(const tal_t *ctx,
508+
const struct pay_parameters *params,
509+
struct graph **graph, double **arc_prob_cost,
510+
s64 **arc_fee_cost, s64 **arc_capacity)
586511
{
587-
struct linear_network * linear_network = tal(ctx, struct linear_network);
588512
const struct gossmap *gossmap = params->rq->gossmap;
589-
590513
const size_t max_num_chans = gossmap_max_chan_idx(gossmap);
591514
const size_t max_num_arcs = max_num_chans * ARCS_PER_CHANNEL;
592515
const size_t max_num_nodes = gossmap_max_node_idx(gossmap);
593516

594-
linear_network->graph =
595-
graph_new(ctx, max_num_nodes, max_num_arcs, ARC_DUAL_BITOFF);
517+
*graph = graph_new(ctx, max_num_nodes, max_num_arcs, ARC_DUAL_BITOFF);
518+
*arc_prob_cost = tal_arr(ctx, double, max_num_arcs);
519+
for (size_t i = 0; i < max_num_arcs; ++i)
520+
(*arc_prob_cost)[i] = DBL_MAX;
596521

597-
linear_network->arc_prob_cost = tal_arr(linear_network,double,max_num_arcs);
598-
for(size_t i=0;i<max_num_arcs;++i)
599-
linear_network->arc_prob_cost[i]=DBL_MAX;
522+
*arc_fee_cost = tal_arr(ctx, s64, max_num_arcs);
523+
for (size_t i = 0; i < max_num_arcs; ++i)
524+
(*arc_fee_cost)[i] = INT64_MAX;
600525

601-
linear_network->arc_fee_cost = tal_arr(linear_network,s64,max_num_arcs);
602-
for(size_t i=0;i<max_num_arcs;++i)
603-
linear_network->arc_fee_cost[i]=INFINITE;
604-
605-
linear_network->capacity = tal_arrz(linear_network,s64,max_num_arcs);
526+
*arc_capacity = tal_arrz(ctx, s64, max_num_arcs);
606527

607528
for(struct gossmap_node *node = gossmap_first_node(gossmap);
608529
node;
@@ -660,25 +581,23 @@ init_linear_network(const tal_t *ctx, const struct pay_parameters *params)
660581

661582
struct arc arc = arc_from_parts(chan_id, half, k, false);
662583

663-
graph_add_arc(linear_network->graph, arc,
584+
graph_add_arc(*graph, arc,
664585
node_obj(node_id),
665586
node_obj(next_id));
666587

667-
linear_network->capacity[arc.idx] = capacity[k];
668-
linear_network->arc_prob_cost[arc.idx] = prob_cost[k];
669-
linear_network->arc_fee_cost[arc.idx] = fee_cost;
588+
(*arc_capacity)[arc.idx] = capacity[k];
589+
(*arc_prob_cost)[arc.idx] = prob_cost[k];
590+
(*arc_fee_cost)[arc.idx] = fee_cost;
670591

671592
// + the respective dual
672-
struct arc dual = arc_dual(linear_network->graph, arc);
593+
struct arc dual = arc_dual(*graph, arc);
673594

674-
linear_network->capacity[dual.idx] = 0;
675-
linear_network->arc_prob_cost[dual.idx] = -prob_cost[k];
676-
linear_network->arc_fee_cost[dual.idx] = -fee_cost;
595+
(*arc_capacity)[dual.idx] = 0;
596+
(*arc_prob_cost)[dual.idx] = -prob_cost[k];
597+
(*arc_fee_cost)[dual.idx] = -fee_cost;
677598
}
678599
}
679600
}
680-
681-
return linear_network;
682601
}
683602

684603
// flow on directed channels
@@ -873,8 +792,8 @@ static struct flow **
873792
get_flow_paths(const tal_t *ctx,
874793
const tal_t *working_ctx,
875794
const struct pay_parameters *params,
876-
const struct linear_network *linear_network,
877-
const struct residual_network *residual_network)
795+
const struct graph *graph,
796+
const s64 *arc_residual_capacity)
878797
{
879798
struct flow **flows = tal_arr(ctx,struct flow*,0);
880799

@@ -897,7 +816,6 @@ get_flow_paths(const tal_t *ctx,
897816
// Convert the arc based residual network flow into a flow in the
898817
// directed channel network.
899818
// Compute balance on the nodes.
900-
const struct graph *graph = linear_network->graph;
901819
for (struct node n = {.idx = 0}; n.idx < max_num_nodes; n.idx++) {
902820
for(struct arc arc = node_adjacency_begin(graph,n);
903821
!node_adjacency_end(arc);
@@ -906,7 +824,7 @@ get_flow_paths(const tal_t *ctx,
906824
if(arc_is_dual(graph, arc))
907825
continue;
908826
struct node m = arc_head(graph,arc);
909-
s64 flow = get_arc_flow(residual_network,
827+
s64 flow = get_arc_flow(arc_residual_capacity,
910828
graph, arc);
911829
u32 chanidx;
912830
int chandir;
@@ -1000,40 +918,48 @@ struct flow **minflow(const tal_t *ctx,
1000918
params->base_fee_penalty = base_fee_penalty_estimate(amount);
1001919

1002920
// build the uncertainty network with linearization and residual arcs
1003-
struct linear_network *linear_network= init_linear_network(working_ctx, params);
1004-
const struct graph *graph = linear_network->graph;
921+
struct graph *graph;
922+
double *arc_prob_cost;
923+
s64 *arc_fee_cost;
924+
s64 *arc_capacity;
925+
init_linear_network(working_ctx, params, &graph, &arc_prob_cost,
926+
&arc_fee_cost, &arc_capacity);
927+
1005928
const size_t max_num_arcs = graph_max_num_arcs(graph);
1006929
const size_t max_num_nodes = graph_max_num_nodes(graph);
1007-
struct residual_network *residual_network =
1008-
alloc_residual_network(working_ctx, max_num_nodes, max_num_arcs);
930+
s64 *arc_cost;
931+
s64 *node_potential;
932+
s64 *node_excess;
933+
arc_cost = tal_arrz(working_ctx, s64, max_num_arcs);
934+
node_potential = tal_arrz(working_ctx, s64, max_num_nodes);
935+
node_excess = tal_arrz(working_ctx, s64, max_num_nodes);
1009936

1010937
const struct node dst = {.idx = gossmap_node_idx(rq->gossmap, target)};
1011938
const struct node src = {.idx = gossmap_node_idx(rq->gossmap, source)};
1012939

1013-
init_residual_network(linear_network,residual_network);
1014940

1015941
/* Since we have constraint accuracy, ask to find a payment solution
1016942
* that can pay a bit more than the actual value rathen than undershoot it.
1017943
* That's why we use the ceil function here. */
1018944
const u64 pay_amount =
1019945
amount_msat_ratio_ceil(params->amount, params->accuracy);
1020946

1021-
if (!simple_feasibleflow(working_ctx, linear_network->graph, src, dst,
1022-
residual_network->cap, pay_amount)) {
947+
if (!simple_feasibleflow(working_ctx, graph, src, dst,
948+
arc_capacity, pay_amount)) {
1023949
rq_log(tmpctx, rq, LOG_INFORM,
1024950
"%s failed: unable to find a feasible flow.", __func__);
1025951
goto fail;
1026952
}
1027-
combine_cost_function(working_ctx, linear_network, residual_network,
1028-
rq->biases, mu);
953+
combine_cost_function(working_ctx, graph, arc_prob_cost, arc_fee_cost,
954+
rq->biases, mu, arc_cost);
1029955

1030956
/* We solve a linear MCF problem. */
1031957
if (!mcf_refinement(working_ctx,
1032-
linear_network->graph,
1033-
residual_network->excess,
1034-
residual_network->cap,
1035-
residual_network->cost,
1036-
residual_network->potential)) {
958+
graph,
959+
node_excess,
960+
arc_capacity,
961+
arc_cost,
962+
node_potential)) {
1037963
rq_log(tmpctx, rq, LOG_BROKEN,
1038964
"%s: MCF optimization step failed", __func__);
1039965
goto fail;
@@ -1043,7 +969,7 @@ struct flow **minflow(const tal_t *ctx,
1043969
* Actual amounts considering fees are computed for every
1044970
* channel in the routes. */
1045971
flow_paths = get_flow_paths(ctx, working_ctx, params,
1046-
linear_network, residual_network);
972+
graph, arc_capacity);
1047973
if(!flow_paths){
1048974
rq_log(tmpctx, rq, LOG_BROKEN,
1049975
"%s: failed to extract flow paths from the MCF solution",

0 commit comments

Comments
 (0)