|
41 | 41 |
|
42 | 42 | package com.oracle.graal.python.builtins.modules;
|
43 | 43 |
|
| 44 | +import static com.oracle.graal.python.builtins.PythonBuiltinClassType.TypeError; |
| 45 | + |
44 | 46 | import java.math.BigInteger;
|
45 | 47 | import java.util.List;
|
46 | 48 |
|
47 | 49 | import com.oracle.graal.python.builtins.Builtin;
|
48 | 50 | import com.oracle.graal.python.builtins.CoreFunctions;
|
49 | 51 | import com.oracle.graal.python.builtins.PythonBuiltins;
|
50 | 52 | import com.oracle.graal.python.builtins.objects.PNone;
|
| 53 | +import com.oracle.graal.python.builtins.objects.buffer.PythonBufferAccessLibrary; |
| 54 | +import com.oracle.graal.python.builtins.objects.buffer.PythonBufferAcquireLibrary; |
51 | 55 | import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
|
52 | 56 | import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
|
53 | 57 | import com.oracle.graal.python.builtins.objects.dict.PDict;
|
|
60 | 64 | import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
|
61 | 65 | import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
|
62 | 66 | import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
|
| 67 | +import com.oracle.graal.python.nodes.util.CannotCastException; |
| 68 | +import com.oracle.graal.python.nodes.util.CastToJavaStringNode; |
63 | 69 | import com.oracle.graal.python.runtime.sequence.PSequence;
|
64 | 70 | import com.oracle.truffle.api.CompilerDirectives;
|
65 | 71 | import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
|
|
70 | 76 | import com.oracle.truffle.api.dsl.Specialization;
|
71 | 77 | import com.oracle.truffle.api.dsl.TypeSystemReference;
|
72 | 78 | import com.oracle.truffle.api.frame.VirtualFrame;
|
| 79 | +import com.oracle.truffle.api.library.CachedLibrary; |
73 | 80 |
|
74 | 81 | @CoreFunctions(defineModule = OperatorModuleBuiltins.MODULE_NAME)
|
75 | 82 | public class OperatorModuleBuiltins extends PythonBuiltins {
|
@@ -174,10 +181,60 @@ public Object doObject(VirtualFrame frame, Object value, Object index,
|
174 | 181 | public abstract static class CompareDigestNode extends PythonBinaryBuiltinNode {
|
175 | 182 |
|
176 | 183 | @Specialization
|
177 |
| - public boolean doString(String arg1, String arg2) { |
178 |
| - return arg1.equals(arg2); |
| 184 | + public boolean compare(Object left, Object right, |
| 185 | + @Cached CastToJavaStringNode cast, |
| 186 | + @CachedLibrary(limit = "3") PythonBufferAcquireLibrary bufferAcquireLib, |
| 187 | + @CachedLibrary(limit = "3") PythonBufferAccessLibrary bufferLib) { |
| 188 | + try { |
| 189 | + String leftString = cast.execute(left); |
| 190 | + String rightString = cast.execute(right); |
| 191 | + return tscmp(leftString, rightString); |
| 192 | + } catch (CannotCastException e) { |
| 193 | + if (!bufferAcquireLib.hasBuffer(left) || !bufferAcquireLib.hasBuffer(right)) { |
| 194 | + throw raise(TypeError, "unsupported operand types(s) or combination of types: '%p' and '%p'", left, right); |
| 195 | + } |
| 196 | + Object leftBuffer = bufferAcquireLib.acquireReadonly(left); |
| 197 | + try { |
| 198 | + Object rightBuffer = bufferAcquireLib.acquireReadonly(right); |
| 199 | + try { |
| 200 | + return tscmp(bufferLib.getCopiedByteArray(leftBuffer), bufferLib.getCopiedByteArray(rightBuffer)); |
| 201 | + } finally { |
| 202 | + bufferLib.release(rightBuffer); |
| 203 | + } |
| 204 | + } finally { |
| 205 | + bufferLib.release(leftBuffer); |
| 206 | + } |
| 207 | + } |
179 | 208 | }
|
180 | 209 |
|
| 210 | + // Comparison that's safe against timing attacks |
| 211 | + @TruffleBoundary |
| 212 | + private boolean tscmp(String leftIn, String right) { |
| 213 | + String left = leftIn; |
| 214 | + int result = 0; |
| 215 | + if (left.length() != right.length()) { |
| 216 | + left = right; |
| 217 | + result = 1; |
| 218 | + } |
| 219 | + for (int i = 0; i < left.length(); i++) { |
| 220 | + result |= left.charAt(i) ^ right.charAt(i); |
| 221 | + } |
| 222 | + return result == 0; |
| 223 | + } |
| 224 | + |
| 225 | + @TruffleBoundary |
| 226 | + private boolean tscmp(byte[] leftIn, byte[] right) { |
| 227 | + byte[] left = leftIn; |
| 228 | + int result = 0; |
| 229 | + if (left.length != right.length) { |
| 230 | + left = right; |
| 231 | + result = 1; |
| 232 | + } |
| 233 | + for (int i = 0; i < left.length; i++) { |
| 234 | + result |= left[i] ^ right[i]; |
| 235 | + } |
| 236 | + return result == 0; |
| 237 | + } |
181 | 238 | }
|
182 | 239 |
|
183 | 240 | @Builtin(name = "index", minNumOfPositionalArgs = 1)
|
|
0 commit comments