Skip to content

Commit 9f7c928

Browse files
Biotronicn8sh
authored andcommitted
Fix issue 20099 - Memoize should handle lambdas
1 parent 74a40f4 commit 9f7c928

File tree

1 file changed

+142
-60
lines changed

1 file changed

+142
-60
lines changed

std/functional.d

Lines changed: 142 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,15 @@ alias pipe(fun...) = compose!(Reverse!(fun));
12451245
assert(compose!(`a + 0.5`, `to!(int)(a) + 1`, foo)(1) == 2.5);
12461246
}
12471247

1248+
private template getOverloads(alias fun)
1249+
{
1250+
import std.meta : AliasSeq;
1251+
static if (__traits(compiles, __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun))))
1252+
alias getOverloads = __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun));
1253+
else
1254+
alias getOverloads = AliasSeq!fun;
1255+
}
1256+
12481257
/**
12491258
* $(LINK2 https://en.wikipedia.org/wiki/Memoization, Memoizes) a function so as
12501259
* to avoid repeated computation. The memoization structure is a hash table keyed by a
@@ -1280,87 +1289,129 @@ Note:
12801289
*/
12811290
template memoize(alias fun)
12821291
{
1283-
import std.traits : ReturnType;
1284-
// https://issues.dlang.org/show_bug.cgi?id=13580
1285-
// alias Args = Parameters!fun;
1292+
import std.traits : Parameters;
1293+
import std.meta : anySatisfy;
1294+
1295+
// Specific overloads:
1296+
alias overloads = getOverloads!fun;
1297+
static foreach (fn; overloads)
1298+
static if (is(Parameters!fn))
1299+
alias memoize = impl!(Parameters!fn);
1300+
1301+
enum isTemplate(alias a) = __traits(isTemplate, a);
1302+
static if (anySatisfy!(isTemplate, overloads))
1303+
{
1304+
// Generic implementation
1305+
alias memoize = impl;
1306+
}
12861307

1287-
ReturnType!fun memoize(Parameters!fun args)
1308+
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
12881309
{
1289-
alias Args = Parameters!fun;
1290-
import std.typecons : Tuple;
1310+
import std.typecons : Tuple, tuple;
12911311
import std.traits : Unqual;
12921312

1293-
static Unqual!(ReturnType!fun)[Tuple!Args] memo;
1294-
auto t = Tuple!Args(args);
1295-
if (auto p = t in memo)
1296-
return *p;
1297-
auto r = fun(args);
1298-
memo[t] = r;
1299-
return r;
1313+
static if (args.length > 0)
1314+
{
1315+
static Unqual!(typeof(fun(args)))[Tuple!(typeof(args))] memo;
1316+
1317+
auto t = Tuple!Args(args);
1318+
if (auto p = t in memo)
1319+
return *p;
1320+
auto r = fun(args);
1321+
memo[t] = r;
1322+
return r;
1323+
}
1324+
else
1325+
{
1326+
static typeof(fun(args)) result = fun(args);
1327+
return result;
1328+
}
13001329
}
13011330
}
13021331

13031332
/// ditto
13041333
template memoize(alias fun, uint maxSize)
13051334
{
1306-
import std.traits : ReturnType;
1307-
// https://issues.dlang.org/show_bug.cgi?id=13580
1308-
// alias Args = Parameters!fun;
1309-
ReturnType!fun memoize(Parameters!fun args)
1335+
import std.traits : Parameters;
1336+
import std.meta : anySatisfy;
1337+
1338+
// Specific overloads:
1339+
alias overloads = getOverloads!fun;
1340+
static foreach (fn; overloads)
1341+
static if (is(Parameters!fn))
1342+
alias memoize = impl!(Parameters!fn);
1343+
1344+
enum isTemplate(alias a) = __traits(isTemplate, a);
1345+
static if (anySatisfy!(isTemplate, overloads))
13101346
{
1311-
import std.meta : staticMap;
1312-
import std.traits : hasIndirections, Unqual;
1313-
import std.typecons : tuple;
1314-
static struct Value { staticMap!(Unqual, Parameters!fun) args; Unqual!(ReturnType!fun) res; }
1315-
static Value[] memo;
1316-
static size_t[] initialized;
1347+
// Generic implementation
1348+
alias memoize = impl;
1349+
}
13171350

1318-
if (!memo.length)
1351+
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
1352+
{
1353+
static if (args.length > 0)
13191354
{
1320-
import core.memory : GC;
1355+
import std.meta : staticMap;
1356+
import std.traits : hasIndirections, Unqual;
1357+
import std.typecons : tuple;
1358+
alias returnType = typeof(fun(args));
1359+
static struct Value { staticMap!(Unqual, Args) args; Unqual!returnType res; }
1360+
static Value[] memo;
1361+
static size_t[] initialized;
1362+
1363+
if (!memo.length)
1364+
{
1365+
import core.memory : GC;
13211366

1322-
// Ensure no allocation overflows
1323-
static assert(maxSize < size_t.max / Value.sizeof);
1324-
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
1367+
// Ensure no allocation overflows
1368+
static assert(maxSize < size_t.max / Value.sizeof);
1369+
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
13251370

1326-
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
1327-
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
1328-
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
1329-
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
1330-
}
1371+
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
1372+
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
1373+
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
1374+
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
1375+
}
13311376

1332-
import core.bitop : bt, bts;
1333-
import core.lifetime : emplace;
1377+
import core.bitop : bt, bts;
1378+
import core.lifetime : emplace;
13341379

1335-
size_t hash;
1336-
foreach (ref arg; args)
1337-
hash = hashOf(arg, hash);
1338-
// cuckoo hashing
1339-
immutable idx1 = hash % maxSize;
1340-
if (!bt(initialized.ptr, idx1))
1341-
{
1342-
emplace(&memo[idx1], args, fun(args));
1343-
// only set to initialized after setting args and value
1344-
// https://issues.dlang.org/show_bug.cgi?id=14025
1345-
bts(initialized.ptr, idx1);
1380+
size_t hash;
1381+
foreach (ref arg; args)
1382+
hash = hashOf(arg, hash);
1383+
// cuckoo hashing
1384+
immutable idx1 = hash % maxSize;
1385+
if (!bt(initialized.ptr, idx1))
1386+
{
1387+
emplace(&memo[idx1], args, fun(args));
1388+
// only set to initialized after setting args and value
1389+
// https://issues.dlang.org/show_bug.cgi?id=14025
1390+
bts(initialized.ptr, idx1);
1391+
return memo[idx1].res;
1392+
}
1393+
else if (memo[idx1].args == args)
1394+
return memo[idx1].res;
1395+
// FNV prime
1396+
immutable idx2 = (hash * 16_777_619) % maxSize;
1397+
if (!bt(initialized.ptr, idx2))
1398+
{
1399+
emplace(&memo[idx2], memo[idx1]);
1400+
bts(initialized.ptr, idx2);
1401+
}
1402+
else if (memo[idx2].args == args)
1403+
return memo[idx2].res;
1404+
else if (idx1 != idx2)
1405+
memo[idx2] = memo[idx1];
1406+
1407+
memo[idx1] = Value(args, fun(args));
13461408
return memo[idx1].res;
13471409
}
1348-
else if (memo[idx1].args == args)
1349-
return memo[idx1].res;
1350-
// FNV prime
1351-
immutable idx2 = (hash * 16_777_619) % maxSize;
1352-
if (!bt(initialized.ptr, idx2))
1410+
else
13531411
{
1354-
emplace(&memo[idx2], memo[idx1]);
1355-
bts(initialized.ptr, idx2);
1412+
static typeof(fun(args)) result = fun(args);
1413+
return result;
13561414
}
1357-
else if (memo[idx2].args == args)
1358-
return memo[idx2].res;
1359-
else if (idx1 != idx2)
1360-
memo[idx2] = memo[idx1];
1361-
1362-
memo[idx1] = Value(args, fun(args));
1363-
return memo[idx1].res;
13641415
}
13651416
}
13661417

@@ -1420,6 +1471,37 @@ unittest
14201471
assert(fact(10) == 3628800);
14211472
}
14221473

1474+
// Issue 20099
1475+
@system unittest // not @safe due to memoize
1476+
{
1477+
int i = 3;
1478+
alias a = memoize!((n) => i + n);
1479+
alias b = memoize!((n) => i + n, 3);
1480+
1481+
assert(a(3) == 6);
1482+
assert(b(3) == 6);
1483+
}
1484+
1485+
@system unittest // not @safe due to memoize
1486+
{
1487+
static Object objNum(int a) { return new Object(); }
1488+
assert(memoize!objNum(0) is memoize!objNum(0U));
1489+
assert(memoize!(objNum, 3)(0) is memoize!(objNum, 3)(0U));
1490+
}
1491+
1492+
@system unittest // not @safe due to memoize
1493+
{
1494+
struct S
1495+
{
1496+
static int fun() { return 0; }
1497+
static int fun(int i) { return 1; }
1498+
}
1499+
assert(memoize!(S.fun)() == 0);
1500+
assert(memoize!(S.fun)(3) == 1);
1501+
assert(memoize!(S.fun, 3)() == 0);
1502+
assert(memoize!(S.fun, 3)(3) == 1);
1503+
}
1504+
14231505
@system unittest // not @safe due to memoize
14241506
{
14251507
import core.math : sqrt;

0 commit comments

Comments
 (0)