Skip to content

Commit 4fdf3a7

Browse files
committed
Fix issue 20099
1 parent 45e3ca3 commit 4fdf3a7

File tree

1 file changed

+138
-62
lines changed

1 file changed

+138
-62
lines changed

std/functional.d

Lines changed: 138 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,15 @@ alias pipe(fun...) = compose!(Reverse!(fun));
12251225
assert(compose!(`a + 0.5`, `to!(int)(a) + 1`, foo)(1) == 2.5);
12261226
}
12271227

1228+
private template getOverloads(alias fun)
1229+
{
1230+
import std.meta : AliasSeq;
1231+
static if (__traits(compiles, __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun))))
1232+
alias getOverloads = __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun));
1233+
else
1234+
alias getOverloads = AliasSeq!fun;
1235+
}
1236+
12281237
/**
12291238
* $(LINK2 https://en.wikipedia.org/wiki/Memoization, Memoizes) a function so as
12301239
* to avoid repeated computation. The memoization structure is a hash table keyed by a
@@ -1260,87 +1269,123 @@ Note:
12601269
*/
12611270
template memoize(alias fun)
12621271
{
1263-
import std.traits : ReturnType;
1264-
// https://issues.dlang.org/show_bug.cgi?id=13580
1265-
// alias Args = Parameters!fun;
1266-
1267-
ReturnType!fun memoize(Parameters!fun args)
1272+
import std.traits : Parameters;
1273+
import std.meta : anySatisfy;
1274+
1275+
// Specific overloads:
1276+
alias overloads = getOverloads!fun;
1277+
static foreach (fn; overloads)
1278+
static if (is(Parameters!fn))
1279+
alias memoize = impl!(Parameters!fn);
1280+
1281+
enum isTemplate(alias a) = __traits(isTemplate, a);
1282+
static if (anySatisfy!(isTemplate, overloads))
12681283
{
1269-
alias Args = Parameters!fun;
1270-
import std.typecons : Tuple;
1284+
// Generic implementation
1285+
alias memoize = impl;
1286+
}
1287+
1288+
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
1289+
{
1290+
import std.typecons : Tuple, tuple;
12711291
import std.traits : Unqual;
12721292

1273-
static Unqual!(ReturnType!fun)[Tuple!Args] memo;
1274-
auto t = Tuple!Args(args);
1275-
if (auto p = t in memo)
1276-
return *p;
1277-
auto r = fun(args);
1278-
memo[t] = r;
1279-
return r;
1293+
static if (args.length > 0)
1294+
{
1295+
static Unqual!(typeof(fun(args)))[Tuple!(typeof(args))] memo;
1296+
return memo.require(tuple(args), fun(args));
1297+
}
1298+
else
1299+
{
1300+
static typeof(fun(args)) result = fun(args);
1301+
return result;
1302+
}
12801303
}
12811304
}
12821305

12831306
/// ditto
12841307
template memoize(alias fun, uint maxSize)
12851308
{
1286-
import std.traits : ReturnType;
1287-
// https://issues.dlang.org/show_bug.cgi?id=13580
1288-
// alias Args = Parameters!fun;
1289-
ReturnType!fun memoize(Parameters!fun args)
1309+
import std.traits : Parameters;
1310+
import std.meta : anySatisfy;
1311+
1312+
// Specific overloads:
1313+
alias overloads = getOverloads!fun;
1314+
static foreach (fn; overloads)
1315+
static if (is(Parameters!fn))
1316+
alias memoize = impl!(Parameters!fn);
1317+
1318+
enum isTemplate(alias a) = __traits(isTemplate, a);
1319+
static if (anySatisfy!(isTemplate, overloads))
12901320
{
1291-
import std.meta : staticMap;
1292-
import std.traits : hasIndirections, Unqual;
1293-
import std.typecons : tuple;
1294-
static struct Value { staticMap!(Unqual, Parameters!fun) args; Unqual!(ReturnType!fun) res; }
1295-
static Value[] memo;
1296-
static size_t[] initialized;
1297-
1298-
if (!memo.length)
1321+
// Generic implementation
1322+
alias memoize = impl;
1323+
}
1324+
1325+
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
1326+
{
1327+
static if (args.length > 0)
12991328
{
1300-
import core.memory : GC;
1329+
import std.meta : staticMap;
1330+
import std.traits : hasIndirections, Unqual;
1331+
import std.typecons : tuple;
1332+
alias returnType = typeof(fun(args));
1333+
static struct Value { staticMap!(Unqual, Args) args; Unqual!returnType res; }
1334+
static Value[] memo;
1335+
static size_t[] initialized;
1336+
1337+
if (!memo.length)
1338+
{
1339+
import core.memory : GC;
13011340

1302-
// Ensure no allocation overflows
1303-
static assert(maxSize < size_t.max / Value.sizeof);
1304-
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
1341+
// Ensure no allocation overflows
1342+
static assert(maxSize < size_t.max / Value.sizeof);
1343+
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
13051344

1306-
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
1307-
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
1308-
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
1309-
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
1310-
}
1345+
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
1346+
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
1347+
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
1348+
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
1349+
}
13111350

1312-
import core.bitop : bt, bts;
1313-
import std.conv : emplace;
1351+
import core.bitop : bt, bts;
1352+
import std.conv : emplace;
13141353

1315-
size_t hash;
1316-
foreach (ref arg; args)
1317-
hash = hashOf(arg, hash);
1318-
// cuckoo hashing
1319-
immutable idx1 = hash % maxSize;
1320-
if (!bt(initialized.ptr, idx1))
1321-
{
1322-
emplace(&memo[idx1], args, fun(args));
1323-
// only set to initialized after setting args and value
1324-
// https://issues.dlang.org/show_bug.cgi?id=14025
1325-
bts(initialized.ptr, idx1);
1354+
size_t hash;
1355+
foreach (ref arg; args)
1356+
hash = hashOf(arg, hash);
1357+
// cuckoo hashing
1358+
immutable idx1 = hash % maxSize;
1359+
if (!bt(initialized.ptr, idx1))
1360+
{
1361+
emplace(&memo[idx1], args, fun(args));
1362+
// only set to initialized after setting args and value
1363+
// https://issues.dlang.org/show_bug.cgi?id=14025
1364+
bts(initialized.ptr, idx1);
1365+
return memo[idx1].res;
1366+
}
1367+
else if (memo[idx1].args == args)
1368+
return memo[idx1].res;
1369+
// FNV prime
1370+
immutable idx2 = (hash * 16_777_619) % maxSize;
1371+
if (!bt(initialized.ptr, idx2))
1372+
{
1373+
emplace(&memo[idx2], memo[idx1]);
1374+
bts(initialized.ptr, idx2);
1375+
}
1376+
else if (memo[idx2].args == args)
1377+
return memo[idx2].res;
1378+
else if (idx1 != idx2)
1379+
memo[idx2] = memo[idx1];
1380+
1381+
memo[idx1] = Value(args, fun(args));
13261382
return memo[idx1].res;
13271383
}
1328-
else if (memo[idx1].args == args)
1329-
return memo[idx1].res;
1330-
// FNV prime
1331-
immutable idx2 = (hash * 16_777_619) % maxSize;
1332-
if (!bt(initialized.ptr, idx2))
1384+
else
13331385
{
1334-
emplace(&memo[idx2], memo[idx1]);
1335-
bts(initialized.ptr, idx2);
1386+
static typeof(fun(args)) result = fun(args);
1387+
return result;
13361388
}
1337-
else if (memo[idx2].args == args)
1338-
return memo[idx2].res;
1339-
else if (idx1 != idx2)
1340-
memo[idx2] = memo[idx1];
1341-
1342-
memo[idx1] = Value(args, fun(args));
1343-
return memo[idx1].res;
13441389
}
13451390
}
13461391

@@ -1400,6 +1445,37 @@ unittest
14001445
assert(fact(10) == 3628800);
14011446
}
14021447

1448+
// Issue 20099
1449+
@system unittest // not @safe due to memoize
1450+
{
1451+
int i = 3;
1452+
alias a = memoize!((n) => i + n);
1453+
alias b = memoize!((n) => i + n, 3);
1454+
1455+
assert(a(3) == 6);
1456+
assert(b(3) == 6);
1457+
}
1458+
1459+
@system unittest // not @safe due to memoize
1460+
{
1461+
static Object objNum(int a) { return new Object(); }
1462+
assert(memoize!objNum(0) is memoize!objNum(0U));
1463+
assert(memoize!(objNum, 3)(0) is memoize!(objNum, 3)(0U));
1464+
}
1465+
1466+
@system unittest // not @safe due to memoize
1467+
{
1468+
struct S
1469+
{
1470+
static int fun() { return 0; }
1471+
static int fun(int i) { return 1; }
1472+
}
1473+
assert(memoize!(S.fun)() == 0);
1474+
assert(memoize!(S.fun)(3) == 1);
1475+
assert(memoize!(S.fun, 3)() == 0);
1476+
assert(memoize!(S.fun, 3)(3) == 1);
1477+
}
1478+
14031479
@system unittest // not @safe due to memoize
14041480
{
14051481
import core.math : sqrt;

0 commit comments

Comments
 (0)