Skip to content

Commit 258735c

Browse files
timfelboris-spas
authored andcommitted
fix wrong specialization in random.seed
1 parent 47426a1 commit 258735c

File tree

1 file changed

+46
-19
lines changed
  • graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/random

1 file changed

+46
-19
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/random/RandomBuiltins.java

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,19 @@
5151
import com.oracle.graal.python.builtins.objects.ints.PInt;
5252
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
5353
import com.oracle.graal.python.nodes.PGuards;
54+
import com.oracle.graal.python.nodes.SpecialMethodNames;
5455
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
5556
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
57+
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
5658
import com.oracle.graal.python.runtime.exception.PythonErrorType;
59+
import com.oracle.truffle.api.CompilerDirectives;
60+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
5761
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
58-
import com.oracle.truffle.api.dsl.Cached;
62+
import com.oracle.truffle.api.dsl.Fallback;
5963
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
6064
import com.oracle.truffle.api.dsl.NodeFactory;
6165
import com.oracle.truffle.api.dsl.Specialization;
66+
import com.oracle.truffle.api.dsl.TypeSystemReference;
6267
import com.oracle.truffle.api.nodes.UnexpectedResultException;
6368

6469
@CoreFunctions(extendClasses = PythonBuiltinClassType.PRandom)
@@ -70,6 +75,7 @@ protected List<? extends NodeFactory<? extends PythonBuiltinNode>> getNodeFactor
7075

7176
@Builtin(name = "seed", fixedNumOfPositionalArgs = 2)
7277
@GenerateNodeFactory
78+
@TypeSystemReference(PythonArithmeticTypes.class)
7379
public abstract static class SeedNode extends PythonBuiltinNode {
7480

7581
@Specialization
@@ -80,7 +86,7 @@ public PNone seed(PRandom random, @SuppressWarnings("unused") PNone none) {
8086
}
8187

8288
@Specialization
83-
public PNone seed(PRandom random, int inputSeed) {
89+
public PNone seed(PRandom random, long inputSeed) {
8490
random.setSeed(inputSeed);
8591
return PNone.NONE;
8692
}
@@ -97,26 +103,47 @@ public PNone seed(PRandom random, double inputSeed) {
97103
return PNone.NONE;
98104
}
99105

100-
@Specialization(rewriteOn = UnexpectedResultException.class)
101-
public PNone seedObject(PRandom random, Object inputSeed,
102-
@Cached("create(__HASH__)") LookupAndCallUnaryNode callHash) throws UnexpectedResultException {
103-
long hash = callHash.executeLong(inputSeed);
104-
random.setSeed(hash);
105-
return PNone.NONE;
106-
}
106+
@CompilationFinal boolean gotUnexpectedHashResult = false;
107+
@Child LookupAndCallUnaryNode callHash;
107108

108-
@Specialization(replaces = "seedObject")
109-
public PNone seedNonLong(PRandom random, Object inputSeed,
110-
@Cached("create(__HASH__)") LookupAndCallUnaryNode callHash) {
111-
Object object = callHash.executeObject(inputSeed);
112-
if (PGuards.isInteger(object)) {
113-
random.setSeed(((Number) object).intValue());
114-
} else if (PGuards.isPInt(object)) {
115-
random.setSeed(((PInt) object).intValue());
109+
@Fallback
110+
public PNone seedNonLong(Object random, Object inputSeed) {
111+
if (random instanceof PRandom) {
112+
if (callHash == null) {
113+
CompilerDirectives.transferToInterpreterAndInvalidate();
114+
callHash = insert(LookupAndCallUnaryNode.create(SpecialMethodNames.__HASH__));
115+
}
116+
Object hashResult = null;
117+
if (!gotUnexpectedHashResult) {
118+
try {
119+
long hash = callHash.executeLong(inputSeed);
120+
((PRandom) random).setSeed(hash);
121+
return PNone.NONE;
122+
} catch (UnexpectedResultException e) {
123+
CompilerDirectives.transferToInterpreterAndInvalidate();
124+
gotUnexpectedHashResult = true;
125+
hashResult = e.getResult();
126+
}
127+
}
128+
if (gotUnexpectedHashResult) {
129+
if (hashResult == null) {
130+
hashResult = callHash.executeObject(inputSeed);
131+
}
132+
if (PGuards.isInteger(hashResult)) {
133+
((PRandom) random).setSeed(((Number) hashResult).intValue());
134+
} else if (PGuards.isPInt(hashResult)) {
135+
((PRandom) random).setSeed(((PInt) hashResult).intValue());
136+
} else {
137+
throw raise(PythonErrorType.TypeError, "__hash__ method should return an integer");
138+
}
139+
return PNone.NONE;
140+
} else {
141+
assert false : "cannot reach here";
142+
return PNone.NONE;
143+
}
116144
} else {
117-
throw raise(PythonErrorType.TypeError, "__hash__ method should return an integer");
145+
throw raise(PythonErrorType.TypeError, "descriptor 'seed' requires a '_random.Random' object but received a '%p'", random);
118146
}
119-
return PNone.NONE;
120147
}
121148
}
122149

0 commit comments

Comments
 (0)