Skip to content

Commit d805427

Browse files
committed
rework fuse and iteration
1 parent 24bd124 commit d805427

File tree

6 files changed

+219
-141
lines changed

6 files changed

+219
-141
lines changed

source/mir/algorithm/iteration.d

Lines changed: 26 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -551,21 +551,11 @@ private void checkShapesMatch(
551551
{
552552
import mir.ndslice.fuse: fuseShape;
553553
static assert(slices[i].fuseShape.length >= N);
554-
assert(slices[i].fuseShape[0 .. N] == slices[0].shape, msgShape);
554+
assert(cast(size_t[N])slices[i].fuseShape[0 .. N] == slices[0].shape, msgShape);
555555
}
556556
}
557557
}
558558

559-
template frontOf(size_t N)
560-
{
561-
static if (N == 0)
562-
enum frontOf = "";
563-
else
564-
{
565-
enum i = N - 1;
566-
enum frontOf = frontOf!i ~ "slices[" ~ i.stringof ~ "].front, ";
567-
}
568-
}
569559

570560
package(mir) template allFlattened(args...)
571561
{
@@ -587,11 +577,19 @@ private template areAllContiguousSlices(Slices...)
587577
{
588578
import mir.ndslice.traits: isContiguousSlice;
589579
static if (allSatisfy!(isContiguousSlice, Slices))
590-
enum areAllContiguousSlices = Slices[0].N > 1;
580+
enum areAllContiguousSlices = Slices[0].N > 1 && areAllContiguousSlicesImpl!(Slices[0].N, Slices[1 .. $]);
591581
else
592582
enum areAllContiguousSlices = false;
593583
}
594584

585+
private template areAllContiguousSlicesImpl(size_t N, Slices...)
586+
{
587+
static if (Slices.length == 0)
588+
enum areAllContiguousSlicesImpl = true;
589+
else
590+
enum areAllContiguousSlicesImpl = Slices[0].N == N && areAllContiguousSlicesImpl!(N, Slices[1 .. $]);
591+
}
592+
595593
version(LDC) {}
596594
else version(GNU) {}
597595
else version (Windows) {}
@@ -663,9 +661,9 @@ S reduceImpl(alias fun, S, Slices...)(S seed, scope Slices slices)
663661
do
664662
{
665663
static if (DimensionCount!(Slices[0]) == 1)
666-
seed = mixin("fun(seed, " ~ frontOf!(Slices.length) ~ ")");
664+
seed = fun(seed, frontOf!slices);
667665
else
668-
seed = mixin(".reduceImpl!fun(seed," ~ frontOf!(Slices.length) ~ ")");
666+
seed = .reduceImpl!fun(seed, frontOf!slices);
669667
foreach_reverse(ref slice; slices)
670668
slice.popFront;
671669
}
@@ -920,9 +918,9 @@ void eachImpl(alias fun, Slices...)(scope Slices slices)
920918
do
921919
{
922920
static if (DimensionCount!(Slices[0]) == 1)
923-
mixin("fun(" ~ frontOf!(Slices.length) ~ ");");
921+
fun(frontOf!slices);
924922
else
925-
mixin(".eachImpl!fun(" ~ frontOf!(Slices.length) ~ ");");
923+
.eachImpl!fun(frontOf!slices);
926924
foreach_reverse(i; Iota!(Slices.length))
927925
slices[i].popFront;
928926
}
@@ -1678,15 +1676,15 @@ bool findImpl(alias fun, size_t N, Slices...)(scope ref size_t[N] backwardIndex,
16781676
{
16791677
static if (DimensionCount!(Slices[0]) == 1)
16801678
{
1681-
if (mixin("fun(" ~ frontOf!(Slices.length) ~ ")"))
1679+
if (fun(frontOf!slices))
16821680
{
16831681
backwardIndex[0] = slices[0].length;
16841682
return true;
16851683
}
16861684
}
16871685
else
16881686
{
1689-
if (mixin("findImpl!fun(backwardIndex[1 .. $], " ~ frontOf!(Slices.length) ~ ")"))
1687+
if (findImpl!fun(backwardIndex[1 .. $], frontOf!slices))
16901688
{
16911689
backwardIndex[0] = slices[0].length;
16921690
return true;
@@ -1965,12 +1963,12 @@ size_t anyImpl(alias fun, Slices...)(scope Slices slices)
19651963
{
19661964
static if (DimensionCount!(Slices[0]) == 1)
19671965
{
1968-
if (mixin("fun(" ~ frontOf!(Slices.length) ~ ")"))
1966+
if (fun(frontOf!slices))
19691967
return true;
19701968
}
19711969
else
19721970
{
1973-
if (mixin("anyImpl!fun(" ~ frontOf!(Slices.length) ~ ")"))
1971+
if (anyImpl!fun(frontOf!slices))
19741972
return true;
19751973
}
19761974
foreach_reverse(ref slice; slices)
@@ -2125,12 +2123,12 @@ size_t allImpl(alias fun, Slices...)(scope Slices slices)
21252123
{
21262124
static if (DimensionCount!(Slices[0]) == 1)
21272125
{
2128-
if (!mixin("fun(" ~ frontOf!(Slices.length) ~ ")"))
2126+
if (!fun(frontOf!slices))
21292127
return false;
21302128
}
21312129
else
21322130
{
2133-
if (!mixin("allImpl!fun(" ~ frontOf!(Slices.length) ~ ")"))
2131+
if (!allImpl!fun(frontOf!slices))
21342132
return false;
21352133
}
21362134
foreach_reverse(ref slice; slices)
@@ -2636,42 +2634,18 @@ size_t countImpl(alias fun, Slices...)(scope Slices slices)
26362634
{
26372635
static if (DimensionCount!(Slices[0]) == 1)
26382636
{
2639-
if(mixin("fun(" ~ frontOf!(Slices.length) ~ ")"))
2637+
if(fun(frontOf!slices))
26402638
ret++;
26412639
}
26422640
else
2643-
ret += mixin(".countImpl!fun(" ~ frontOf!(Slices.length) ~ ")");
2641+
ret += .countImpl!fun(frontOf!slices);
26442642
foreach_reverse(ref slice; slices)
26452643
slice.popFront;
26462644
}
26472645
while(!slices[0].empty);
26482646
return ret;
26492647
}
26502648

2651-
private template selectBackOf(size_t N, string input)
2652-
{
2653-
static if (N == 0)
2654-
enum selectBackOf = "";
2655-
else
2656-
{
2657-
enum i = N - 1;
2658-
enum selectBackOf = selectBackOf!(i, input) ~
2659-
"lightScope(slices[" ~ i.stringof ~ "]).selectBack!0(" ~ input ~ "), ";
2660-
}
2661-
}
2662-
2663-
private template frontSelectFrontOf(size_t N, string input)
2664-
{
2665-
static if (N == 0)
2666-
enum frontSelectFrontOf = "";
2667-
else
2668-
{
2669-
enum i = N - 1;
2670-
enum frontSelectFrontOf = frontSelectFrontOf!(i, input) ~
2671-
"lightScope(slices[" ~ i.stringof ~ "]).front!0.selectFront!0(" ~ input ~ "), ";
2672-
}
2673-
}
2674-
26752649
/++
26762650
Returns: max length across all dimensions.
26772651
+/
@@ -2756,7 +2730,7 @@ template eachLower(alias fun)
27562730
if ((n + k) < m)
27572731
{
27582732
val = m - (n + k);
2759-
mixin(".eachImpl!fun(" ~ selectBackOf!(Slices.length, "val") ~ ");");
2733+
.eachImpl!fun(selectBackOf!(val, slices));
27602734
}
27612735

27622736
size_t i;
@@ -2771,7 +2745,7 @@ template eachLower(alias fun)
27712745
do
27722746
{
27732747
val = i - k + 1;
2774-
mixin(".eachImpl!fun(" ~ frontSelectFrontOf!(Slices.length, "val") ~ ");");
2748+
.eachImpl!fun(frontSelectFrontOf!(val, slices));
27752749

27762750
foreach(ref slice; slices)
27772751
slice.popFront!0;
@@ -3127,30 +3101,6 @@ version(mir_test) unittest
31273101
[ 6, 7, 18]]);
31283102
}
31293103

3130-
private template frontSelectBackOf(size_t N, string input)
3131-
{
3132-
static if (N == 0)
3133-
enum frontSelectBackOf = "";
3134-
else
3135-
{
3136-
enum i = N - 1;
3137-
enum frontSelectBackOf = frontSelectBackOf!(i, input) ~
3138-
"lightScope(slices[" ~ i.stringof ~ "]).front.selectBack!0(" ~ input ~ "), ";
3139-
}
3140-
}
3141-
3142-
private template selectFrontOf(size_t N, string input)
3143-
{
3144-
static if (N == 0)
3145-
enum selectFrontOf = "";
3146-
else
3147-
{
3148-
enum i = N - 1;
3149-
enum selectFrontOf = selectFrontOf!(i, input) ~
3150-
"lightScope(slices[" ~ i.stringof ~ "]).selectFront!0(" ~ input ~ "), ";
3151-
}
3152-
}
3153-
31543104
/++
31553105
The call `eachUpper!(fun)(slice1, ..., sliceN)` evaluates `fun` on the upper
31563106
triangle in `slice1, ..., sliceN`, respectively.
@@ -3223,7 +3173,7 @@ template eachUpper(alias fun)
32233173
if (k < 0)
32243174
{
32253175
val = -k;
3226-
mixin(".eachImpl!fun(" ~ selectFrontOf!(Slices.length, "val") ~ ");");
3176+
.eachImpl!fun(selectFrontOf!(val, slices));
32273177

32283178
foreach(ref slice; slices)
32293179
slice.popFrontExactly!0(-k);
@@ -3233,7 +3183,7 @@ template eachUpper(alias fun)
32333183
do
32343184
{
32353185
val = (n - k) - i;
3236-
mixin(".eachImpl!fun(" ~ frontSelectBackOf!(Slices.length, "val") ~ ");");
3186+
.eachImpl!fun(frontSelectBackOf!(val, slices));
32373187

32383188
foreach(ref slice; slices)
32393189
slice.popFront;

source/mir/combinatorics/package.d

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ struct IndexedRoR(Collection, Range)
217217
private Collection c;
218218
private Range r;
219219

220+
///
221+
alias DeepElement = ForeachType!Range;
222+
220223
///
221224
this()(Collection collection, Range range)
222225
{
@@ -243,7 +246,7 @@ struct IndexedRoR(Collection, Range)
243246
}
244247

245248
/// Input range primitives
246-
auto ref front()() @property
249+
auto front()() @property
247250
{
248251
import mir.ndslice.slice: isSlice, sliced;
249252
import mir.ndslice.topology: indexed;
@@ -409,6 +412,9 @@ struct Permutations(T)
409412
private bool _empty;
410413
private size_t _max_states = 1, _pos;
411414

415+
///
416+
alias DeepElement = const T;
417+
412418
/**
413419
state should have the length of `n - 1`,
414420
whereas the length of indices should be `n`
@@ -642,12 +648,12 @@ struct CartesianPower(T)
642648
{
643649
import mir.ndslice.slice: Slice;
644650

645-
private:
646-
T[] _state;
647-
size_t n;
648-
size_t _max_states, _pos;
651+
private T[] _state;
652+
private size_t n;
653+
private size_t _max_states, _pos;
649654

650-
public:
655+
///
656+
alias DeepElement = const T;
651657

652658
/// state should have the length of `repeat`
653659
this()(size_t n, T[] state) @safe pure nothrow @nogc
@@ -893,12 +899,12 @@ struct Combinations(T)
893899
{
894900
import mir.ndslice.slice: Slice;
895901

896-
private:
897-
T[] state;
898-
size_t n;
899-
size_t max_states, pos;
902+
private T[] state;
903+
private size_t n;
904+
private size_t max_states, pos;
900905

901-
public:
906+
///
907+
alias DeepElement = const T;
902908

903909
/// state should have the length of `repeat`
904910
this()(size_t n, T[] state) @safe pure nothrow @nogc
@@ -1221,12 +1227,12 @@ struct CombinationsRepeat(T)
12211227
{
12221228
import mir.ndslice.slice: Slice;
12231229

1224-
private:
1225-
T[] state;
1226-
size_t n;
1227-
size_t max_states, pos;
1230+
private T[] state;
1231+
private size_t n;
1232+
private size_t max_states, pos;
12281233

1229-
public:
1234+
///
1235+
alias DeepElement = const T;
12301236

12311237
/// state should have the length of `repeat`
12321238
this()(size_t n, T[] state) @safe pure nothrow @nogc

source/mir/ndslice/concatenation.d

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -211,28 +211,6 @@ version(mir_test) unittest
211211
assert(s.slicedNdField == s.length.iota);
212212
}
213213

214-
template frontOf(size_t N)
215-
{
216-
static if (N == 0)
217-
enum frontOf = "";
218-
else
219-
{
220-
enum i = N - 1;
221-
enum frontOf = frontOf!i ~ "slices[" ~ i.stringof ~ "].front!d, ";
222-
}
223-
}
224-
225-
template frontOfSt(size_t N)
226-
{
227-
static if (N == 0)
228-
enum frontOfSt = "";
229-
else
230-
{
231-
enum i = N - 1;
232-
enum frontOfSt = frontOfSt!i ~ "st._slices[" ~ i.stringof ~ "].front!d, ";
233-
}
234-
}
235-
236214
///
237215
enum bool isConcatenation(T) = is(T : Concatenation!(dim, Slices), size_t dim, Slices...);
238216
///
@@ -359,9 +337,9 @@ struct Concatenation(size_t dim, Slices...)
359337
}
360338
else
361339
{
340+
import mir.ndslice.internal: frontOfDim;
362341
enum elemDim = d < dim ? dim - 1 : dim;
363-
alias slices = _slices;
364-
return mixin(`concatenation!elemDim(` ~ frontOf!(Slices.length) ~ `)`);
342+
return concatenation!elemDim(frontOfDim!(d, _slices));
365343
}
366344
}
367345

@@ -402,8 +380,10 @@ auto applyFront(size_t d = 0, alias fun, size_t dim, Slices...)(Concatenation!(d
402380
}
403381
else
404382
{
383+
import mir.ndslice.internal: frontOfDim;
405384
enum elemDim = d < dim ? dim - 1 : dim;
406-
return fun(mixin(`concatenation!elemDim(` ~ frontOfSt!(Slices.length) ~ `)`));
385+
auto slices = st._slices;
386+
return fun(concatenation!elemDim(frontOfDim!(d, slices)));
407387
}
408388
}
409389

0 commit comments

Comments
 (0)