Skip to content

Commit 82e0539

Browse files
committed
Add cache for op_sub on Representation
As the first cache for a binary special method, this exposes (and we fix) a number of bugs in handle formation and use in the call site.
1 parent fc99b99 commit 82e0539

File tree

7 files changed

+137
-67
lines changed

7 files changed

+137
-67
lines changed

rt4core/src/main/java/uk/co/farowl/vsj4/core/PyRT.java

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ public BinaryOpCallSite(SpecialMethod op) {
321321
setTarget(fallbackMH.bindTo(this));
322322
}
323323

324+
@Override
325+
public String toString() {
326+
return String.format(
327+
"BinaryOpCallSite[%s fallbacks=%s chain=%s]",
328+
op.name(), fallbackCount, chainLength);
329+
}
330+
324331
/**
325332
* Compute the result of the call for this particular pair of
326333
* arguments, and update the site to do this efficiently for the
@@ -371,22 +378,18 @@ private Object fallback(Object v, Object w) throws Throwable {
371378
// class(v) does not fix type(v).
372379
if (wType.hasFeature(TypeFlag.REPLACEABLE)) {
373380
// class(w) does not fix type(w).
374-
try {
375-
/*
376-
* It is complex to create a "double bounce"
377-
* handle so we compute the answer but do not
378-
* cache the method.
379-
*/
380-
Object r = dynamicResult(vType, v, wType, w);
381-
if (r != Py.NotImplemented) { return r; }
382-
} catch (EmptyException e) {}
383-
// Empty or r=NotImplemented
384-
throw op.operandError(v, w);
381+
/*
382+
* Both vMH and wRH would be bounce handles. We
383+
* currently hypothesise that it is not worth
384+
* updating the call site with such a combination:
385+
* rather just go for the answer.
386+
*/
387+
return dynamicResult(vType, v, wType, w);
385388

386389
} else {
387390
// class(w) fixes type(w).
388391
vMH = op.handle(vRep); // = op.bounce
389-
if ((wRH = rop.handle(wType)) == rop.empty) { // XXX
392+
if ((wRH = rop.handle(wRep)) == rop.empty) {
390393
// We need only consider vMH.
391394
mh = vMH;
392395
} else {
@@ -404,7 +407,7 @@ private Object fallback(Object v, Object w) throws Throwable {
404407
// class(v) fixes type(v).
405408
// class(w) does not fix type(w).
406409
wRH = rop.handle(wRep); // = op.bounce
407-
if ((vMH = op.handle(vRep)) == op.empty) { // XXX
410+
if ((vMH = op.handle(vRep)) == op.empty) {
408411
// We need only consider wRH
409412
mh = wRH;
410413
} else {
@@ -425,31 +428,46 @@ private Object fallback(Object v, Object w) throws Throwable {
425428
// class(v) fixes type(v).
426429
// class(w) fixes type(w).
427430
vMH = op.handle(vRep);
428-
if (vType == wType) {
429-
// We need only consider vMH
431+
if (vType == wType
432+
|| (wRH = rop.handle(wRep)) == rop.empty) {
433+
// We need only consider vMH (even if empty)
430434
mh = vMH;
435+
} else if (vMH == op.empty) {
436+
// We need only consider wRH
437+
mh = wRH;
438+
} else if (wType.isSubTypeOf(vType)) {
439+
// Try w.rop(v),then v.rop(w).
440+
mh = firstImplementer(wRH, vMH);
431441
} else {
432-
wRH = rop.handle(wRep);
433-
if (wType.isSubTypeOf(vType)) {
434-
// Try w.rop(v),then v.rop(w).
435-
mh = firstImplementer(wRH, vMH);
436-
} else {
437-
// Try v.op(w) then w.rop(v)
438-
mh = firstImplementer(vMH, wRH);
439-
}
442+
// Try v.op(w) then w.rop(v)
443+
mh = firstImplementer(vMH, wRH);
440444
}
441445
}
442446

443-
// MH for guarded invocation (becomes new target)
444-
// guardMH = insertArguments(CLASS2_GUARD, 0, vClass,
445-
// wClass);
446-
// targetMH = guardWithTest(guardMH, mh, getTarget());
447-
// setTarget(targetMH);
448-
// chainLength += 1;
449-
450-
MethodHandle resultMH =
451-
firstImplementer(mh, op.errorHandle());
452-
return resultMH.invokeExact(v, w);
447+
// Convert a final NotImplemented into an error
448+
mh = firstImplementer(mh, op.errorHandle());
449+
450+
/*
451+
* If the composite handle throws, it throws here and we do
452+
* not bind a new target. If it's a value-dependent one-off,
453+
* we'll get another go.
454+
*/
455+
Object r = mh.invokeExact(v, w);
456+
457+
/*
458+
* Decide whether to embed the composite handle in the
459+
* target of the site.
460+
*/
461+
if (chainLength < MAX_CHAIN) {
462+
// MH for guarded invocation (becomes new target)
463+
guardMH = insertArguments(CLASS2_GUARD, 0, vClass,
464+
wClass);
465+
targetMH = guardWithTest(guardMH, mh, getTarget());
466+
setTarget(targetMH);
467+
chainLength += 1;
468+
}
469+
470+
return r;
453471
}
454472

455473
private Object dynamicResult(BaseType vType, Object v,

rt4core/src/main/java/uk/co/farowl/vsj4/kernel/BaseType.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,9 +1263,7 @@ private void updateSpecialMethodCache(SpecialMethod sm,
12631263
List<Representation> reps = representations();
12641264
if (this instanceof ReplaceableType) {
12651265
assert reps.get(0) instanceof SharedRepresentation;
1266-
// The cache delegates to the type (always).
1267-
assert sm.handle(reps.get(0)) == sm.bounce;
1268-
// So we update the type object itself.
1266+
// Update the type object itself.
12691267
reps = List.of(this);
12701268
}
12711269
updateSpecialMethodCache(sm, result, reps);

rt4core/src/main/java/uk/co/farowl/vsj4/kernel/Representation.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -664,10 +664,9 @@ public MethodHandle op_add() {
664664
* @return handle on {@code __rsub__} with signature
665665
* {@link Signature#BINARY}.
666666
*/
667-
@SuppressWarnings("static-method")
668-
public MethodHandle op_rsub() {
669-
return SpecialMethod.op_rsub.generic;
670-
}
667+
public MethodHandle op_rsub() { return op_rsub; }
668+
669+
private MethodHandle op_rsub;
671670

672671
/**
673672
* Return a matching implementation of {@code __sub__} with
@@ -676,10 +675,9 @@ public MethodHandle op_rsub() {
676675
* @return handle on {@code __sub__} with signature
677676
* {@link Signature#BINARY}.
678677
*/
679-
@SuppressWarnings("static-method")
680-
public MethodHandle op_sub() {
681-
return SpecialMethod.op_sub.generic;
682-
}
678+
public MethodHandle op_sub() { return op_sub; }
679+
680+
private MethodHandle op_sub;
683681

684682
/**
685683
* Return a matching implementation of {@code __rmul__} with

rt4core/src/main/java/uk/co/farowl/vsj4/kernel/SharedRepresentation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class SharedRepresentation extends Representation {
3131
for (SpecialMethod sm : SpecialMethod.values()) {
3232
if (sm.hasCache()) {
3333
// Cache bounces decision to the type.
34-
sm.setCache(this, sm.bounce);
34+
sm.setBounce(this);
3535
}
3636
}
3737
}

rt4core/src/main/java/uk/co/farowl/vsj4/kernel/SpecialMethod.java

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ public enum SpecialMethod {
629629
this.signature = signature;
630630
this.methodName = dunder(methodName);
631631
this.isreflected = isreflected;
632+
// Cannot be a reflected method and have a reflection
632633
assert reflected == null || reflected.isreflected;
633634
this.reflected = reflected;
634635
// If doc is short, assume it's a symbol. Fall back on name.
@@ -979,7 +980,9 @@ public MethodHandle errorHandle() {
979980
* {@link Representation} to the given {@code MethodHandle}. In the
980981
* case of a (binary) reflected special method (like
981982
* {@code __rsub__}), the handle is transformed by swapping its
982-
* arguments.
983+
* arguments. This is the correct thing to so when the handle is on
984+
* a method definition, because wherever we invoke a reflected
985+
* handle, we do so with the receiver ({@code self}) second.
983986
* <p>
984987
* If this special method does not have a cache in
985988
* {@link Representation} objects, this is a no-op, and effectively
@@ -1013,7 +1016,7 @@ void setCache(Representation rep, MethodHandle mh) {
10131016
*
10141017
* @param rep target {@code Representation}
10151018
*/
1016-
void setGeneric(Representation rep) { setCache(rep, generic); }
1019+
void setGeneric(Representation rep) { cache.set(rep, generic); }
10171020

10181021
/**
10191022
* Set the cache for this {@code SpecialMethod} in the
@@ -1023,7 +1026,19 @@ void setCache(Representation rep, MethodHandle mh) {
10231026
*
10241027
* @param rep target {@code Representation}
10251028
*/
1026-
void setEmpty(Representation rep) { setCache(rep, this.empty); }
1029+
void setEmpty(Representation rep) { cache.set(rep, empty); }
1030+
1031+
/**
1032+
* Set the cache for this {@code SpecialMethod} in the
1033+
* {@link Representation} to be {@link #bounce}. The bounce handle
1034+
* invokes the corresponding special method cache on the type
1035+
* object.
1036+
*
1037+
* @param rep target {@code Representation}
1038+
*/
1039+
void setBounce(SharedRepresentation rep) {
1040+
cache.set(rep, bounce);
1041+
}
10271042

10281043
@Override
10291044
public java.lang.String toString() {
@@ -1468,6 +1483,9 @@ static MethodHandle bounceMH(SpecialMethod sm) {
14681483

14691484
// We aim to create:
14701485
// bounce = λ(s, ...): trampoline(sm)(type(s), s, ...)
1486+
// or for a *reflected* binary operation:
1487+
// bounce = λ(v, w): trampoline(sm)(type(w), v, w)
1488+
// because we shall call it with the receiver second.
14711489
try {
14721490
/*
14731491
* Find the trampoline method handle smt. The signature
@@ -1493,6 +1511,11 @@ static MethodHandle bounceMH(SpecialMethod sm) {
14931511
type = type.asType(MethodType.methodType(T, O));
14941512

14951513
// bounce = λ(s,...): smt(type(s),s,...)
1514+
// or bounce = λ(v,w): smt(type(w),v,w)
1515+
if (sm.isreflected) {
1516+
// type = λ(v,w): type(w)
1517+
type = MethodHandles.dropArguments(type, 0, O);
1518+
}
14961519
MethodHandle bounce =
14971520
MethodHandles.foldArguments(smt, type);
14981521

@@ -1773,8 +1796,20 @@ private static Object op_add(PyType vType, Object v, Object w)
17731796
}
17741797

17751798
@SuppressWarnings("unused")
1776-
private static Object op_radd(PyType wType, Object w, Object v)
1799+
private static Object op_radd(PyType wType, Object v, Object w)
1800+
throws Throwable {
1801+
return BaseType.cast(wType).op_radd().invokeExact(v, w);
1802+
}
1803+
1804+
@SuppressWarnings("unused")
1805+
private static Object op_sub(PyType vType, Object v, Object w)
1806+
throws Throwable {
1807+
return BaseType.cast(vType).op_sub().invokeExact(v, w);
1808+
}
1809+
1810+
@SuppressWarnings("unused")
1811+
private static Object op_rsub(PyType wType, Object v, Object w)
17771812
throws Throwable {
1778-
return BaseType.cast(wType).op_radd().invokeExact(w, v);
1813+
return BaseType.cast(wType).op_rsub().invokeExact(v, w);
17791814
}
17801815
}

rt4core/src/test/java/uk/co/farowl/vsj4/core/BinaryCallSiteTest.java

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@
3030
import org.slf4j.LoggerFactory;
3131

3232
import uk.co.farowl.vsj4.core.PyRT.BinaryOpCallSite;
33+
import uk.co.farowl.vsj4.kernel.BaseType;
34+
import uk.co.farowl.vsj4.kernel.Representation;
3335
import uk.co.farowl.vsj4.kernel.SpecialMethod;
3436
import uk.co.farowl.vsj4.kernel.SpecialMethod.Signature;
3537
import uk.co.farowl.vsj4.support.InterpreterError;
38+
import uk.co.farowl.vsj4.types.TypeFlag;
3639

3740
/**
3841
* Test of the mechanism for invoking and updating binary call sites on
@@ -344,6 +347,8 @@ void testMatchSpecial(String name, String mix,
344347
}
345348
}
346349

350+
private record ClassPair(Class<?> vClass, Class<?> wClass) {}
351+
347352
/**
348353
* Invoke a special method call site for the presented values in
349354
* order, examining fall-back and new specialisations added as
@@ -361,12 +366,14 @@ void testFallbackCounts(String name, String mix,
361366
List<Object> values) throws Throwable {
362367

363368
MethodHandle invoker = cs.dynamicInvoker();
369+
SpecialMethod op = cs.op;
370+
SpecialMethod rop = cs.rop;
364371

365372
/*
366373
* Track the classes that (we think) are cached in the call
367374
* site's handle chain.
368375
*/
369-
Set<Class<?>> cached = new HashSet<>();
376+
Set<ClassPair> chain = new HashSet<>();
370377
int lastCount = 0;
371378

372379
// Invoke for each of the values
@@ -376,25 +383,39 @@ void testFallbackCounts(String name, String mix,
376383
@SuppressWarnings("unused")
377384
Object r = invoker.invokeExact(v, w);
378385

379-
if (!cached.contains(v.getClass())) {
380-
// Uncached class so should have called
381-
// fallback.
386+
Class<?> vClass = v.getClass();
387+
Representation vRep = Abstract.representation(v);
388+
BaseType vType = vRep.pythonType(v);
389+
MethodHandle vMH = op.handle(vRep);
390+
391+
Class<?> wClass = w.getClass();
392+
Representation wRep = Abstract.representation(w);
393+
BaseType wType = wRep.pythonType(w);
394+
MethodHandle wRH = rop.handle(wRep);
395+
396+
ClassPair pair = new ClassPair(vClass, wClass);
397+
if (!chain.contains(pair)) {
398+
// Uncached class: should have called fallback.
382399
lastCount += 1;
383400
}
384401
assertEquals(lastCount, cs.fallbackCount,
385402
"fallback calls");
386403

387404
/*
388-
* If the site is not full and the inner
389-
* SpecialMethod is a cached type, it should have
390-
* been added to the chain.
405+
* If the site is not full the handle might have
406+
* been added to the chain. The rules for this may
407+
* be somewhat complicated/fluid.
391408
*/
392-
if (cached.size() < BinaryOpCallSite.MAX_CHAIN) {
393-
if (cs.op.hasCache()) {
394-
cached.add(v.getClass());
395-
}
409+
if (chain.size() >= BinaryOpCallSite.MAX_CHAIN) {
410+
// Don't embed.
411+
} else if (vType.hasFeature(TypeFlag.REPLACEABLE)
412+
&& wType.hasFeature(TypeFlag.REPLACEABLE)) {
413+
// Don't embed.
414+
} else {
415+
chain.add(pair);
396416
}
397-
assertEquals(cached.size(), cs.chainLength,
417+
418+
assertEquals(chain.size(), cs.chainLength,
398419
"chain length");
399420
}
400421
}

rt4core/src/test/java/uk/co/farowl/vsj4/core/UnaryCallSiteTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ void testFallbackCounts(String name, String mix,
366366
* Track the classes that (we think) are cached in the call
367367
* site's handle chain.
368368
*/
369-
Set<Class<?>> cached = new HashSet<>();
369+
Set<Class<?>> chain = new HashSet<>();
370370
int lastCount = 0;
371371

372372
// Invoke for each of the values
@@ -375,7 +375,7 @@ void testFallbackCounts(String name, String mix,
375375
@SuppressWarnings("unused")
376376
Object r = invoker.invokeExact(x);
377377

378-
if (!cached.contains(x.getClass())) {
378+
if (!chain.contains(x.getClass())) {
379379
// Uncached class so should have called fallback.
380380
lastCount += 1;
381381
}
@@ -387,10 +387,10 @@ void testFallbackCounts(String name, String mix,
387387
* is a cached type, it should have been added to the
388388
* chain.
389389
*/
390-
if (cached.size() < UnaryOpCallSite.MAX_CHAIN) {
391-
if (cs.op.hasCache()) { cached.add(x.getClass()); }
390+
if (chain.size() < UnaryOpCallSite.MAX_CHAIN) {
391+
if (cs.op.hasCache()) { chain.add(x.getClass()); }
392392
}
393-
assertEquals(cached.size(), cs.chainLength,
393+
assertEquals(chain.size(), cs.chainLength,
394394
"chain length");
395395
}
396396
}

0 commit comments

Comments
 (0)