Skip to content

Commit 1c5f483

Browse files
committed
Fix: _codecs.utf_8_decode could not handle incomplete code points.
1 parent 7c1a29f commit 1c5f483

File tree

3 files changed

+74
-62
lines changed

3 files changed

+74
-62
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_codecs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ def test_decode():
4242
# assert codecs.decode(b'[\xff]', 'ascii', errors='ignore') == '[]'
4343
assert codecs.decode(b'[]', 'ascii') == '[]'
4444

45+
data0 = b'\xc5'
46+
data1 = b'\x91'
47+
assert codecs.utf_8_decode(data0, "strict") == ('', 0)
48+
assert codecs.utf_8_decode(data0 + data1, "strict") == ('ő', 2)
49+
assert_raises(UnicodeDecodeError, codecs.utf_8_decode, data0, "strict", True)
50+
4551

4652
def test_encode():
4753
import codecs

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/CodecsModuleBuiltins.java

Lines changed: 60 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -49,6 +49,7 @@
4949
import java.nio.CharBuffer;
5050
import java.nio.charset.CharacterCodingException;
5151
import java.nio.charset.Charset;
52+
import java.nio.charset.CoderResult;
5253
import java.nio.charset.CodingErrorAction;
5354
import java.nio.charset.StandardCharsets;
5455
import java.util.Arrays;
@@ -67,10 +68,13 @@
6768
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.GetInternalByteArrayNode;
6869
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodesFactory.GetInternalByteArrayNodeGen;
6970
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
71+
import com.oracle.graal.python.nodes.expression.CastToBooleanNode;
7072
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
7173
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
7274
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
7375
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
76+
import com.oracle.graal.python.nodes.util.CastToJavaStringNode;
77+
import com.oracle.graal.python.nodes.util.CastToJavaStringNodeGen;
7478
import com.oracle.graal.python.runtime.PythonCore;
7579
import com.oracle.truffle.api.CompilerDirectives;
7680
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -496,90 +500,92 @@ private Object[] encodeString(String self, String errors) {
496500

497501
}
498502

499-
// _codecs.decode(obj, encoding='utf-8', errors='strict')
500-
@Builtin(name = "__truffle_decode", minNumOfPositionalArgs = 1, parameterNames = {"obj", "encoding", "errors"})
503+
// _codecs.decode(obj, encoding='utf-8', errors='strict', final=False)
504+
@Builtin(name = "__truffle_decode", minNumOfPositionalArgs = 1, parameterNames = {"obj", "encoding", "errors", "final"})
501505
@GenerateNodeFactory
502506
abstract static class CodecsDecodeNode extends EncodeBaseNode {
503507
@Child private GetInternalByteArrayNode toByteArrayNode;
508+
@Child private CastToJavaStringNode castEncodingToStringNode;
509+
@Child private CastToBooleanNode castToBooleanNode;
504510

505511
@Specialization
506-
Object decode(PIBytesLike bytes, @SuppressWarnings("unused") PNone encoding, @SuppressWarnings("unused") PNone errors) {
507-
byte[] decoded = getBytes(bytes);
508-
String string = decodeBytes(decoded, "utf-8", "strict");
509-
return factory().createTuple(new Object[]{string, decoded.length});
512+
Object decode(VirtualFrame frame, PIBytesLike bytes, @SuppressWarnings("unused") PNone encoding, @SuppressWarnings("unused") PNone errors, Object finalData) {
513+
ByteBuffer decoded = getBytes(bytes);
514+
String string = decodeBytes(decoded, "utf-8", "strict", castToBoolean(frame, finalData));
515+
return factory().createTuple(new Object[]{string, decoded.position()});
510516
}
511517

512518
@Specialization(guards = {"isString(encoding)"})
513-
Object decode(PIBytesLike bytes, Object encoding, @SuppressWarnings("unused") PNone errors,
514-
@Cached("createClassProfile()") ValueProfile encodingTypeProfile) {
515-
Object profiledEncoding = encodingTypeProfile.profile(encoding);
516-
byte[] decoded = getBytes(bytes);
517-
String string = decodeBytesStrict(decoded, profiledEncoding);
518-
return factory().createTuple(new Object[]{string, decoded.length});
519+
Object decode(VirtualFrame frame, PIBytesLike bytes, Object encoding, @SuppressWarnings("unused") PNone errors, Object finalData) {
520+
ByteBuffer decoded = getBytes(bytes);
521+
String string = decodeBytes(decoded, castToString(encoding), "strict", castToBoolean(frame, finalData));
522+
return factory().createTuple(new Object[]{string, decoded.position()});
519523
}
520524

521525
@Specialization(guards = {"isString(errors)"})
522-
Object decode(PIBytesLike bytes, @SuppressWarnings("unused") PNone encoding, Object errors,
523-
@Cached("createClassProfile()") ValueProfile errorsTypeProfile) {
524-
Object profiledErrors = errorsTypeProfile.profile(errors);
525-
byte[] decoded = getBytes(bytes);
526-
String string = decodeBytesUTF8(decoded, profiledErrors);
527-
return factory().createTuple(new Object[]{string, decoded.length});
526+
Object decode(VirtualFrame frame, PIBytesLike bytes, @SuppressWarnings("unused") PNone encoding, Object errors, Object finalData) {
527+
ByteBuffer decoded = getBytes(bytes);
528+
String string = decodeBytes(decoded, "utf-8", castToString(errors), castToBoolean(frame, finalData));
529+
return factory().createTuple(new Object[]{string, decoded.position()});
528530
}
529531

530532
@Specialization(guards = {"isString(encoding)", "isString(errors)"})
531-
Object decode(PIBytesLike bytes, Object encoding, Object errors,
532-
@Cached("createClassProfile()") ValueProfile encodingTypeProfile,
533-
@Cached("createClassProfile()") ValueProfile errorsTypeProfile) {
534-
Object profiledEncoding = encodingTypeProfile.profile(encoding);
535-
Object profiledErrors = errorsTypeProfile.profile(errors);
536-
byte[] decoded = getBytes(bytes);
537-
String string = decodeBytes(decoded, profiledEncoding, profiledErrors);
538-
return factory().createTuple(new Object[]{string, decoded.length});
533+
Object decode(VirtualFrame frame, PIBytesLike bytes, Object encoding, Object errors, Object finalData) {
534+
ByteBuffer decoded = getBytes(bytes);
535+
String string = decodeBytes(decoded, castToString(encoding), castToString(errors), castToBoolean(frame, finalData));
536+
return factory().createTuple(new Object[]{string, decoded.position()});
539537
}
540538

541539
@Fallback
542-
Object decode(Object bytes, @SuppressWarnings("unused") Object encoding, @SuppressWarnings("unused") Object errors) {
540+
Object decode(Object bytes, @SuppressWarnings("unused") Object encoding, @SuppressWarnings("unused") Object errors, @SuppressWarnings("unused") Object finalData) {
543541
throw raise(TypeError, "a bytes-like object is required, not '%p'", bytes);
544542
}
545543

546-
private byte[] getBytes(PIBytesLike bytesLike) {
547-
if (toByteArrayNode == null) {
548-
CompilerDirectives.transferToInterpreterAndInvalidate();
549-
toByteArrayNode = insert(GetInternalByteArrayNodeGen.create());
550-
}
551-
return toByteArrayNode.execute(bytesLike.getSequenceStorage());
544+
@TruffleBoundary
545+
private static ByteBuffer wrap(byte[] bytes) {
546+
return ByteBuffer.wrap(bytes);
552547
}
553548

554549
@TruffleBoundary
555-
String decodeBytes(byte[] bytes, Object profiledEncoding, Object profiledErrors) {
556-
return decodeBytes(bytes, profiledEncoding.toString(), profiledErrors.toString());
550+
String decodeBytes(ByteBuffer byteBuffer, String encoding, String errors, boolean finalData) {
551+
CodingErrorAction errorAction = convertCodingErrorAction(errors);
552+
Charset charset = getCharset(encoding);
553+
if (charset == null) {
554+
throw raise(LookupError, "unknown encoding: %s", encoding);
555+
}
556+
CharBuffer decoded = CharBuffer.allocate(byteBuffer.capacity());
557+
CoderResult result = charset.newDecoder().onMalformedInput(errorAction).onUnmappableCharacter(errorAction).decode(byteBuffer, decoded, finalData);
558+
if (result.isError()) {
559+
throw raise(UnicodeDecodeError, result.toString());
560+
}
561+
return String.valueOf(decoded.flip());
557562
}
558563

559-
@TruffleBoundary
560-
String decodeBytesStrict(byte[] bytes, Object profiledEncoding) {
561-
return decodeBytes(bytes, profiledEncoding.toString(), "strict");
564+
private ByteBuffer getBytes(PIBytesLike bytesLike) {
565+
if (toByteArrayNode == null) {
566+
CompilerDirectives.transferToInterpreterAndInvalidate();
567+
toByteArrayNode = insert(GetInternalByteArrayNodeGen.create());
568+
}
569+
return wrap(toByteArrayNode.execute(bytesLike.getSequenceStorage()));
562570
}
563571

564-
@TruffleBoundary
565-
String decodeBytesUTF8(byte[] bytes, Object profiledErrors) {
566-
return decodeBytes(bytes, "utf-8", profiledErrors.toString());
572+
private String castToString(Object encodingObj) {
573+
if (castEncodingToStringNode == null) {
574+
CompilerDirectives.transferToInterpreterAndInvalidate();
575+
castEncodingToStringNode = insert(CastToJavaStringNodeGen.create());
576+
}
577+
return castEncodingToStringNode.execute(encodingObj);
567578
}
568579

569-
@TruffleBoundary
570-
String decodeBytes(byte[] bytes, String encoding, String errors) {
571-
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
572-
CodingErrorAction errorAction = convertCodingErrorAction(errors);
573-
Charset charset = getCharset(encoding);
574-
if (charset == null) {
575-
throw raise(LookupError, "unknown encoding: %s", encoding);
580+
private boolean castToBoolean(VirtualFrame frame, Object object) {
581+
if(object == PNone.NO_VALUE) {
582+
return false;
576583
}
577-
try {
578-
CharBuffer decoded = charset.newDecoder().onMalformedInput(errorAction).onUnmappableCharacter(errorAction).decode(byteBuffer);
579-
return String.valueOf(decoded);
580-
} catch (CharacterCodingException e) {
581-
throw raise(UnicodeDecodeError, e);
584+
if (castToBooleanNode == null) {
585+
CompilerDirectives.transferToInterpreterAndInvalidate();
586+
castToBooleanNode = insert(CastToBooleanNode.createIfTrueNode());
582587
}
588+
return castToBooleanNode.executeBoolean(frame, object);
583589
}
584590
}
585591

graalpython/lib-graalpython/_codecs.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def utf_8_encode(string, errors=None):
148148

149149
@__builtin__
150150
def utf_8_decode(string, errors=None, final=False):
151-
return __truffle_decode(string, "utf-8", errors)
151+
return __truffle_decode(string, "utf-8", errors, final)
152152

153153

154154
@__builtin__
@@ -158,7 +158,7 @@ def utf_7_encode(string, errors=None):
158158

159159
@__builtin__
160160
def utf_7_decode(string, errors=None, final=False):
161-
return __truffle_decode(string, "utf-7", errors)
161+
return __truffle_decode(string, "utf-7", errors, final)
162162

163163

164164
@__builtin__
@@ -168,7 +168,7 @@ def utf_16_encode(string, errors=None, byteorder=0):
168168

169169
@__builtin__
170170
def utf_16_decode(string, errors=None, final=False):
171-
return __truffle_decode(string, "utf-16", errors)
171+
return __truffle_decode(string, "utf-16", errors, final)
172172

173173

174174
@__builtin__
@@ -178,7 +178,7 @@ def utf_16_le_encode(string, errors=None):
178178

179179
@__builtin__
180180
def utf_16_le_decode(string, errors=None, final=False):
181-
return __truffle_decode(string, "utf-16-le", errors)
181+
return __truffle_decode(string, "utf-16-le", errors, final)
182182

183183

184184
@__builtin__
@@ -188,7 +188,7 @@ def utf_16_be_encode(string, errors=None):
188188

189189
@__builtin__
190190
def utf_16_be_decode(string, errors=None, final=False):
191-
return __truffle_decode(string, "utf-16-be", errors)
191+
return __truffle_decode(string, "utf-16-be", errors, final)
192192

193193

194194
@__builtin__
@@ -203,7 +203,7 @@ def utf_32_encode(string, errors=None, byteorder=0):
203203

204204
@__builtin__
205205
def utf_32_decode(string, errors=None, final=False):
206-
return __truffle_decode(string, "utf-32", errors)
206+
return __truffle_decode(string, "utf-32", errors, final)
207207

208208

209209
@__builtin__
@@ -213,7 +213,7 @@ def utf_32_le_encode(string, errors=None):
213213

214214
@__builtin__
215215
def utf_32_le_decode(string, errors=None, final=False):
216-
return __truffle_decode(string, "utf-32-le", errors)
216+
return __truffle_decode(string, "utf-32-le", errors, final)
217217

218218

219219
@__builtin__
@@ -223,7 +223,7 @@ def utf_32_be_encode(string, errors=None):
223223

224224
@__builtin__
225225
def utf_32_be_decode(string, errors=None, final=False):
226-
return __truffle_decode(string, "utf-32-be", errors)
226+
return __truffle_decode(string, "utf-32-be", errors, final)
227227

228228

229229
@__builtin__

0 commit comments

Comments
 (0)