Skip to content

Commit 0521623

Browse files
committed
Properly implement compare_digest
1 parent 557bf31 commit 0521623

File tree

1 file changed

+59
-2
lines changed

1 file changed

+59
-2
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/OperatorModuleBuiltins.java

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,17 @@
4141

4242
package com.oracle.graal.python.builtins.modules;
4343

44+
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.TypeError;
45+
4446
import java.math.BigInteger;
4547
import java.util.List;
4648

4749
import com.oracle.graal.python.builtins.Builtin;
4850
import com.oracle.graal.python.builtins.CoreFunctions;
4951
import com.oracle.graal.python.builtins.PythonBuiltins;
5052
import com.oracle.graal.python.builtins.objects.PNone;
53+
import com.oracle.graal.python.builtins.objects.buffer.PythonBufferAccessLibrary;
54+
import com.oracle.graal.python.builtins.objects.buffer.PythonBufferAcquireLibrary;
5155
import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
5256
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
5357
import com.oracle.graal.python.builtins.objects.dict.PDict;
@@ -60,6 +64,8 @@
6064
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
6165
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
6266
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
67+
import com.oracle.graal.python.nodes.util.CannotCastException;
68+
import com.oracle.graal.python.nodes.util.CastToJavaStringNode;
6369
import com.oracle.graal.python.runtime.sequence.PSequence;
6470
import com.oracle.truffle.api.CompilerDirectives;
6571
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -70,6 +76,7 @@
7076
import com.oracle.truffle.api.dsl.Specialization;
7177
import com.oracle.truffle.api.dsl.TypeSystemReference;
7278
import com.oracle.truffle.api.frame.VirtualFrame;
79+
import com.oracle.truffle.api.library.CachedLibrary;
7380

7481
@CoreFunctions(defineModule = OperatorModuleBuiltins.MODULE_NAME)
7582
public class OperatorModuleBuiltins extends PythonBuiltins {
@@ -174,10 +181,60 @@ public Object doObject(VirtualFrame frame, Object value, Object index,
174181
public abstract static class CompareDigestNode extends PythonBinaryBuiltinNode {
175182

176183
@Specialization
177-
public boolean doString(String arg1, String arg2) {
178-
return arg1.equals(arg2);
184+
public boolean compare(Object left, Object right,
185+
@Cached CastToJavaStringNode cast,
186+
@CachedLibrary(limit = "3") PythonBufferAcquireLibrary bufferAcquireLib,
187+
@CachedLibrary(limit = "3") PythonBufferAccessLibrary bufferLib) {
188+
try {
189+
String leftString = cast.execute(left);
190+
String rightString = cast.execute(right);
191+
return tscmp(leftString, rightString);
192+
} catch (CannotCastException e) {
193+
if (!bufferAcquireLib.hasBuffer(left) || !bufferAcquireLib.hasBuffer(right)) {
194+
throw raise(TypeError, "unsupported operand types(s) or combination of types: '%p' and '%p'", left, right);
195+
}
196+
Object leftBuffer = bufferAcquireLib.acquireReadonly(left);
197+
try {
198+
Object rightBuffer = bufferAcquireLib.acquireReadonly(right);
199+
try {
200+
return tscmp(bufferLib.getCopiedByteArray(leftBuffer), bufferLib.getCopiedByteArray(rightBuffer));
201+
} finally {
202+
bufferLib.release(rightBuffer);
203+
}
204+
} finally {
205+
bufferLib.release(leftBuffer);
206+
}
207+
}
179208
}
180209

210+
// Comparison that's safe against timing attacks
211+
@TruffleBoundary
212+
private boolean tscmp(String leftIn, String right) {
213+
String left = leftIn;
214+
int result = 0;
215+
if (left.length() != right.length()) {
216+
left = right;
217+
result = 1;
218+
}
219+
for (int i = 0; i < left.length(); i++) {
220+
result |= left.charAt(i) ^ right.charAt(i);
221+
}
222+
return result == 0;
223+
}
224+
225+
@TruffleBoundary
226+
private boolean tscmp(byte[] leftIn, byte[] right) {
227+
byte[] left = leftIn;
228+
int result = 0;
229+
if (left.length != right.length) {
230+
left = right;
231+
result = 1;
232+
}
233+
for (int i = 0; i < left.length; i++) {
234+
result |= left[i] ^ right[i];
235+
}
236+
return result == 0;
237+
}
181238
}
182239

183240
@Builtin(name = "index", minNumOfPositionalArgs = 1)

0 commit comments

Comments
 (0)