Skip to content

Commit 7ce634c

Browse files
Biotronic0xEAB
authored andcommitted
Fix #10380 - Memoize should handle lambdas
1 parent 0124e47 commit 7ce634c

File tree

1 file changed

+142
-60
lines changed

1 file changed

+142
-60
lines changed

std/functional.d

Lines changed: 142 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,15 @@ alias pipe(fun...) = compose!(Reverse!(fun));
12891289
assert(compose!(`a + 0.5`, `to!(int)(a) + 1`, foo)(1) == 2.5);
12901290
}
12911291

1292+
private template getOverloads(alias fun)
1293+
{
1294+
import std.meta : AliasSeq;
1295+
static if (__traits(compiles, __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun))))
1296+
alias getOverloads = __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun));
1297+
else
1298+
alias getOverloads = AliasSeq!fun;
1299+
}
1300+
12921301
/**
12931302
* $(LINK2 https://en.wikipedia.org/wiki/Memoization, Memoizes) a function so as
12941303
* to avoid repeated computation. The memoization structure is a hash table keyed by a
@@ -1324,87 +1333,129 @@ Note:
13241333
*/
13251334
template memoize(alias fun)
13261335
{
1327-
import std.traits : ReturnType;
1328-
// https://issues.dlang.org/show_bug.cgi?id=13580
1329-
// alias Args = Parameters!fun;
1336+
import std.traits : Parameters;
1337+
import std.meta : anySatisfy;
1338+
1339+
// Specific overloads:
1340+
alias overloads = getOverloads!fun;
1341+
static foreach (fn; overloads)
1342+
static if (is(Parameters!fn))
1343+
alias memoize = impl!(Parameters!fn);
1344+
1345+
enum isTemplate(alias a) = __traits(isTemplate, a);
1346+
static if (anySatisfy!(isTemplate, overloads))
1347+
{
1348+
// Generic implementation
1349+
alias memoize = impl;
1350+
}
13301351

1331-
ReturnType!fun memoize(Parameters!fun args)
1352+
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
13321353
{
1333-
alias Args = Parameters!fun;
1334-
import std.typecons : Tuple;
1354+
import std.typecons : Tuple, tuple;
13351355
import std.traits : Unqual;
13361356

1337-
static Unqual!(ReturnType!fun)[Tuple!Args] memo;
1338-
auto t = Tuple!Args(args);
1339-
if (auto p = t in memo)
1340-
return *p;
1341-
auto r = fun(args);
1342-
memo[t] = r;
1343-
return r;
1357+
static if (args.length > 0)
1358+
{
1359+
static Unqual!(typeof(fun(args)))[Tuple!(typeof(args))] memo;
1360+
1361+
auto t = Tuple!Args(args);
1362+
if (auto p = t in memo)
1363+
return *p;
1364+
auto r = fun(args);
1365+
memo[t] = r;
1366+
return r;
1367+
}
1368+
else
1369+
{
1370+
static typeof(fun(args)) result = fun(args);
1371+
return result;
1372+
}
13441373
}
13451374
}
13461375

13471376
/// ditto
13481377
template memoize(alias fun, uint maxSize)
13491378
{
1350-
import std.traits : ReturnType;
1351-
// https://issues.dlang.org/show_bug.cgi?id=13580
1352-
// alias Args = Parameters!fun;
1353-
ReturnType!fun memoize(Parameters!fun args)
1379+
import std.traits : Parameters;
1380+
import std.meta : anySatisfy;
1381+
1382+
// Specific overloads:
1383+
alias overloads = getOverloads!fun;
1384+
static foreach (fn; overloads)
1385+
static if (is(Parameters!fn))
1386+
alias memoize = impl!(Parameters!fn);
1387+
1388+
enum isTemplate(alias a) = __traits(isTemplate, a);
1389+
static if (anySatisfy!(isTemplate, overloads))
13541390
{
1355-
import std.meta : staticMap;
1356-
import std.traits : hasIndirections, Unqual;
1357-
import std.typecons : tuple;
1358-
static struct Value { staticMap!(Unqual, Parameters!fun) args; Unqual!(ReturnType!fun) res; }
1359-
static Value[] memo;
1360-
static size_t[] initialized;
1391+
// Generic implementation
1392+
alias memoize = impl;
1393+
}
13611394

1362-
if (!memo.length)
1395+
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
1396+
{
1397+
static if (args.length > 0)
13631398
{
1364-
import core.memory : GC;
1399+
import std.meta : staticMap;
1400+
import std.traits : hasIndirections, Unqual;
1401+
import std.typecons : tuple;
1402+
alias returnType = typeof(fun(args));
1403+
static struct Value { staticMap!(Unqual, Args) args; Unqual!returnType res; }
1404+
static Value[] memo;
1405+
static size_t[] initialized;
13651406

1366-
// Ensure no allocation overflows
1367-
static assert(maxSize < size_t.max / Value.sizeof);
1368-
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
1407+
if (!memo.length)
1408+
{
1409+
import core.memory : GC;
13691410

1370-
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
1371-
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
1372-
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
1373-
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
1374-
}
1411+
// Ensure no allocation overflows
1412+
static assert(maxSize < size_t.max / Value.sizeof);
1413+
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
13751414

1376-
import core.bitop : bt, bts;
1377-
import core.lifetime : emplace;
1415+
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
1416+
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
1417+
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
1418+
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
1419+
}
13781420

1379-
size_t hash;
1380-
foreach (ref arg; args)
1381-
hash = hashOf(arg, hash);
1382-
// cuckoo hashing
1383-
immutable idx1 = hash % maxSize;
1384-
if (!bt(initialized.ptr, idx1))
1385-
{
1386-
emplace(&memo[idx1], args, fun(args));
1387-
// only set to initialized after setting args and value
1388-
// https://issues.dlang.org/show_bug.cgi?id=14025
1389-
bts(initialized.ptr, idx1);
1421+
import core.bitop : bt, bts;
1422+
import core.lifetime : emplace;
1423+
1424+
size_t hash;
1425+
foreach (ref arg; args)
1426+
hash = hashOf(arg, hash);
1427+
// cuckoo hashing
1428+
immutable idx1 = hash % maxSize;
1429+
if (!bt(initialized.ptr, idx1))
1430+
{
1431+
emplace(&memo[idx1], args, fun(args));
1432+
// only set to initialized after setting args and value
1433+
// https://issues.dlang.org/show_bug.cgi?id=14025
1434+
bts(initialized.ptr, idx1);
1435+
return memo[idx1].res;
1436+
}
1437+
else if (memo[idx1].args == args)
1438+
return memo[idx1].res;
1439+
// FNV prime
1440+
immutable idx2 = (hash * 16_777_619) % maxSize;
1441+
if (!bt(initialized.ptr, idx2))
1442+
{
1443+
emplace(&memo[idx2], memo[idx1]);
1444+
bts(initialized.ptr, idx2);
1445+
}
1446+
else if (memo[idx2].args == args)
1447+
return memo[idx2].res;
1448+
else if (idx1 != idx2)
1449+
memo[idx2] = memo[idx1];
1450+
1451+
memo[idx1] = Value(args, fun(args));
13901452
return memo[idx1].res;
13911453
}
1392-
else if (memo[idx1].args == args)
1393-
return memo[idx1].res;
1394-
// FNV prime
1395-
immutable idx2 = (hash * 16_777_619) % maxSize;
1396-
if (!bt(initialized.ptr, idx2))
1454+
else
13971455
{
1398-
emplace(&memo[idx2], memo[idx1]);
1399-
bts(initialized.ptr, idx2);
1456+
static typeof(fun(args)) result = fun(args);
1457+
return result;
14001458
}
1401-
else if (memo[idx2].args == args)
1402-
return memo[idx2].res;
1403-
else if (idx1 != idx2)
1404-
memo[idx2] = memo[idx1];
1405-
1406-
memo[idx1] = Value(args, fun(args));
1407-
return memo[idx1].res;
14081459
}
14091460
}
14101461

@@ -1464,6 +1515,37 @@ unittest
14641515
assert(fact(10) == 3628800);
14651516
}
14661517

1518+
// Issue 20099
1519+
@system unittest // not @safe due to memoize
1520+
{
1521+
int i = 3;
1522+
alias a = memoize!((n) => i + n);
1523+
alias b = memoize!((n) => i + n, 3);
1524+
1525+
assert(a(3) == 6);
1526+
assert(b(3) == 6);
1527+
}
1528+
1529+
@system unittest // not @safe due to memoize
1530+
{
1531+
static Object objNum(int a) { return new Object(); }
1532+
assert(memoize!objNum(0) is memoize!objNum(0U));
1533+
assert(memoize!(objNum, 3)(0) is memoize!(objNum, 3)(0U));
1534+
}
1535+
1536+
@system unittest // not @safe due to memoize
1537+
{
1538+
struct S
1539+
{
1540+
static int fun() { return 0; }
1541+
static int fun(int i) { return 1; }
1542+
}
1543+
assert(memoize!(S.fun)() == 0);
1544+
assert(memoize!(S.fun)(3) == 1);
1545+
assert(memoize!(S.fun, 3)() == 0);
1546+
assert(memoize!(S.fun, 3)(3) == 1);
1547+
}
1548+
14671549
@system unittest // not @safe due to memoize
14681550
{
14691551
import core.math : sqrt;

0 commit comments

Comments
 (0)