Skip to content

Commit 74e6330

Browse files
committed
Fix null pointer checks in getfunc
1 parent b2447c2 commit 74e6330

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/CFieldBuiltins.java

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -902,11 +902,11 @@ Object O_get(@SuppressWarnings("unused") FieldGet getfunc, Pointer ptr, @Suppres
902902
@Shared @Cached PointerNodes.ReadPointerNode readPointerNode,
903903
@Cached PointerNodes.ReadPythonObject readPythonObject,
904904
@Cached PRaiseNode raiseNode) {
905-
if (ptr.isNull()) {
905+
Pointer valuePtr = readPointerNode.execute(inliningTarget, ptr);
906+
if (valuePtr.isNull()) {
906907
throw raiseNode.raise(ValueError, ErrorMessages.PY_OBJ_IS_NULL);
907908
}
908-
Pointer value = readPointerNode.execute(inliningTarget, ptr);
909-
return readPythonObject.execute(inliningTarget, value);
909+
return readPythonObject.execute(inliningTarget, valuePtr);
910910
}
911911

912912
@Specialization(guards = "getfunc == c_get")
@@ -956,14 +956,13 @@ static Object z_get(@SuppressWarnings("unused") FieldGet getfunc, Pointer ptr, @
956956
@Shared @Cached PythonObjectFactory factory,
957957
@Shared @Cached PointerNodes.StrLenNode strLenNode,
958958
@Shared @Cached PointerNodes.ReadBytesNode readBytesNode) {
959-
if (!ptr.isNull()) {
960-
// ptr is a char**, we need to deref it to get char*
961-
Pointer valuePtr = readPointerNode.execute(inliningTarget, ptr);
962-
byte[] bytes = readBytesNode.execute(inliningTarget, valuePtr, strLenNode.execute(inliningTarget, valuePtr));
963-
return factory.createBytes(bytes);
964-
} else {
959+
// ptr is a char**, we need to deref it to get char*
960+
Pointer valuePtr = readPointerNode.execute(inliningTarget, ptr);
961+
if (valuePtr.isNull()) {
965962
return PNone.NONE;
966963
}
964+
byte[] bytes = readBytesNode.execute(inliningTarget, valuePtr, strLenNode.execute(inliningTarget, valuePtr));
965+
return factory.createBytes(bytes);
967966
}
968967

969968
@Specialization(guards = "getfunc == Z_get")
@@ -974,14 +973,13 @@ static Object Z_get(@SuppressWarnings("unused") FieldGet getfunc, Pointer ptr, @
974973
@Shared @Cached PointerNodes.ReadBytesNode readBytesNode,
975974
@Cached TruffleString.FromByteArrayNode fromByteArrayNode,
976975
@Cached TruffleString.SwitchEncodingNode switchEncodingNode) {
977-
if (!ptr.isNull()) {
978-
// ptr is a char**, we need to deref it to get char*
979-
Pointer valuePtr = readPointerNode.execute(inliningTarget, ptr);
980-
byte[] bytes = readBytesNode.execute(inliningTarget, valuePtr, wCsLenNode.execute(inliningTarget, valuePtr, size) * WCHAR_T_SIZE);
981-
return switchEncodingNode.execute(fromByteArrayNode.execute(bytes, WCHAR_T_ENCODING, false), TS_ENCODING);
982-
} else {
976+
// ptr is a char**, we need to deref it to get char*
977+
Pointer valuePtr = readPointerNode.execute(inliningTarget, ptr);
978+
if (valuePtr.isNull()) {
983979
return PNone.NONE;
984980
}
981+
byte[] bytes = readBytesNode.execute(inliningTarget, valuePtr, wCsLenNode.execute(inliningTarget, valuePtr, size) * WCHAR_T_SIZE);
982+
return switchEncodingNode.execute(fromByteArrayNode.execute(bytes, WCHAR_T_ENCODING, false), TS_ENCODING);
985983
}
986984

987985
@Specialization(guards = "getfunc == P_get")
@@ -990,10 +988,11 @@ static Object P_get(@SuppressWarnings("unused") FieldGet getfunc, Pointer ptr, @
990988
@Shared @Cached PointerNodes.ReadPointerNode readPointerNode,
991989
@Cached PointerNodes.GetPointerValueAsObjectNode getPointerValueAsObjectNode,
992990
@Shared @Cached PythonObjectFactory factory) {
993-
if (ptr.isNull()) {
991+
Pointer valuePtr = readPointerNode.execute(inliningTarget, ptr);
992+
if (valuePtr.isNull()) {
994993
return 0L;
995994
}
996-
Object p = getPointerValueAsObjectNode.execute(inliningTarget, readPointerNode.execute(inliningTarget, ptr));
995+
Object p = getPointerValueAsObjectNode.execute(inliningTarget, valuePtr);
997996
if (p instanceof Long) {
998997
long val = (long) p;
999998
return val < 0 ? factory.createInt(PInt.longToUnsignedBigInteger(val)) : val;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/CtypesModuleBuiltins.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,9 +1460,6 @@ static Object callGetFunc(Object restype, FFIType rtype, Object result, Object c
14601460
} else {
14611461
pointer = Pointer.nativeMemory(result);
14621462
}
1463-
if (pointer.isNull()) {
1464-
yield Pointer.NULL;
1465-
}
14661463
yield pointer.createReference();
14671464
}
14681465
};

0 commit comments

Comments
 (0)