Skip to content

Commit 12b5176

Browse files
committed
Add 'test_mmap' and make a basic set of tests pass.
1 parent d53d780 commit 12b5176

File tree

4 files changed

+220
-17
lines changed

4 files changed

+220
-17
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from test.support import (TESTFN, run_unittest, import_module, unlink,
2+
requires, _2G, _4G, gc_collect, cpython_only)
3+
import unittest
4+
import os
5+
import re
6+
import itertools
7+
import socket
8+
import sys
9+
import weakref
10+
11+
# Skip test if we can't import mmap.
12+
mmap = import_module('mmap')
13+
14+
PAGESIZE = mmap.PAGESIZE
15+
16+
class MmapTests(unittest.TestCase):
17+
18+
def setUp(self):
19+
if os.path.exists(TESTFN):
20+
os.unlink(TESTFN)
21+
22+
def tearDown(self):
23+
try:
24+
os.unlink(TESTFN)
25+
except OSError:
26+
pass
27+
28+
def test_basic(self):
29+
# Test mmap module on Unix systems and Windows
30+
31+
# Create a file to be mmap'ed.
32+
f = open(TESTFN, 'bw+')
33+
try:
34+
# Write 2 pages worth of data to the file
35+
f.write(b'\0'* PAGESIZE)
36+
f.write(b'foo')
37+
f.write(b'\0'* (PAGESIZE-3) )
38+
f.flush()
39+
m = mmap.mmap(f.fileno(), 2 * PAGESIZE)
40+
finally:
41+
f.close()
42+
43+
# Simple sanity checks
44+
45+
tp = str(type(m)) # SF bug 128713: segfaulted on Linux
46+
self.assertEqual(m.find(b'foo'), PAGESIZE)
47+
48+
self.assertEqual(len(m), 2*PAGESIZE)
49+
50+
self.assertEqual(m[0], 0)
51+
self.assertEqual(m[0:3], b'\0\0\0')
52+
53+
# Shouldn't crash on boundary (Issue #5292)
54+
self.assertRaises(IndexError, m.__getitem__, len(m))
55+
self.assertRaises(IndexError, m.__setitem__, len(m), b'\0')
56+
57+
# Modify the file's content
58+
m[0] = b'3'[0]
59+
m[PAGESIZE +3: PAGESIZE +3+3] = b'bar'
60+
61+
# Check that the modification worked
62+
self.assertEqual(m[0], b'3'[0])
63+
self.assertEqual(m[0:3], b'3\0\0')
64+
self.assertEqual(m[PAGESIZE-1 : PAGESIZE + 7], b'\0foobar\0')
65+
66+
m.flush()
67+
68+
# Test doing a regular expression match in an mmap'ed file
69+
match = re.search(b'[A-Za-z]+', m)
70+
if match is None:
71+
self.fail('regex match on mmap failed!')
72+
else:
73+
start, end = match.span(0)
74+
length = end - start
75+
76+
self.assertEqual(start, PAGESIZE)
77+
self.assertEqual(end, PAGESIZE + 6)
78+
79+
# test seeking around (try to overflow the seek implementation)
80+
m.seek(0,0)
81+
self.assertEqual(m.tell(), 0)
82+
m.seek(42,1)
83+
self.assertEqual(m.tell(), 42)
84+
m.seek(0,2)
85+
self.assertEqual(m.tell(), len(m))
86+
87+
# Try to seek to negative position...
88+
self.assertRaises(ValueError, m.seek, -1)
89+
90+
# Try to seek beyond end of mmap...
91+
self.assertRaises(ValueError, m.seek, 1, 2)
92+
93+
# Try to seek to negative position...
94+
self.assertRaises(ValueError, m.seek, -len(m)-1, 2)
95+
96+
# Try resizing map
97+
# try:
98+
# m.resize(512)
99+
# except SystemError:
100+
# # resize() not supported
101+
# # No messages are printed, since the output of this test suite
102+
# # would then be different across platforms.
103+
# pass
104+
# else:
105+
# # resize() is supported
106+
# self.assertEqual(len(m), 512)
107+
# # Check that we can no longer seek beyond the new size.
108+
# self.assertRaises(ValueError, m.seek, 513, 0)
109+
#
110+
# # Check that the underlying file is truncated too
111+
# # (bug #728515)
112+
# f = open(TESTFN, 'rb')
113+
# try:
114+
# f.seek(0, 2)
115+
# self.assertEqual(f.tell(), 512)
116+
# finally:
117+
# f.close()
118+
# self.assertEqual(m.size(), 512)
119+
120+
m.close()
121+
122+
123+
def test_bad_file_desc(self):
124+
# Try opening a bad file descriptor...
125+
self.assertRaises(OSError, mmap.mmap, -2, 4096)
126+
127+
128+
def test_anonymous(self):
129+
# anonymous mmap.mmap(-1, PAGE)
130+
m = mmap.mmap(-1, PAGESIZE)
131+
for x in range(PAGESIZE):
132+
self.assertEqual(m[x], 0,
133+
"anonymously mmap'ed contents should be zero")
134+
135+
for x in range(PAGESIZE):
136+
b = x & 0xff
137+
m[x] = b
138+
self.assertEqual(m[x], b)
139+
140+
def test_read_all(self):
141+
m = mmap.mmap(-1, 16)
142+
143+
# With no parameters, or None or a negative argument, reads all
144+
m.write(bytes(range(16)))
145+
m.seek(0)
146+
self.assertEqual(m.read(), bytes(range(16)))
147+
m.seek(8)
148+
self.assertEqual(m.read(), bytes(range(8, 16)))
149+
m.seek(16)
150+
self.assertEqual(m.read(), b'')
151+
m.seek(3)
152+
self.assertEqual(m.read(None), bytes(range(3, 16)))
153+
m.seek(4)
154+
self.assertEqual(m.read(-1), bytes(range(4, 16)))
155+
m.seek(5)
156+
self.assertEqual(m.read(-2), bytes(range(5, 16)))
157+
m.seek(9)
158+
self.assertEqual(m.read(-42), bytes(range(9, 16)))
159+
m.close()
160+
161+
162+
163+
def test_context_manager(self):
164+
with mmap.mmap(-1, 10) as m:
165+
self.assertFalse(m.closed)
166+
self.assertTrue(m.closed)
167+
168+
def test_context_manager_exception(self):
169+
# Test that the OSError gets passed through
170+
with self.assertRaises(Exception) as exc:
171+
with mmap.mmap(-1, 10) as m:
172+
raise OSError
173+
self.assertIsInstance(exc.exception, OSError,
174+
"wrong exception raised in context manager")
175+
self.assertTrue(m.closed, "context manager failed")
176+
177+
178+
179+
def test_main():
180+
#run_unittest(MmapTests, LargeMmapTests)
181+
run_unittest(MmapTests)
182+
183+
if __name__ == '__main__':
184+
test_main()

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MMapModuleBuiltins.java

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@
4444
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.ValueError;
4545

4646
import java.io.IOException;
47-
import java.nio.channels.FileChannel.MapMode;
4847
import java.nio.ByteBuffer;
49-
import java.nio.channels.ByteChannel;
48+
import java.nio.channels.FileChannel.MapMode;
5049
import java.nio.channels.SeekableByteChannel;
5150
import java.nio.file.StandardOpenOption;
5251
import java.util.HashSet;
@@ -67,7 +66,6 @@
6766
import com.oracle.truffle.api.dsl.NodeFactory;
6867
import com.oracle.truffle.api.dsl.Specialization;
6968
import com.oracle.truffle.api.profiles.BranchProfile;
70-
import com.oracle.truffle.api.profiles.ValueProfile;
7169

7270
@CoreFunctions(defineModule = "mmap")
7371
public class MMapModuleBuiltins extends PythonBuiltins {
@@ -92,13 +90,13 @@ public MMapModuleBuiltins() {
9290
@GenerateNodeFactory
9391
public abstract static class MMapNode extends PythonBuiltinNode {
9492

95-
private final ValueProfile classProfile = ValueProfile.createClassProfile();
9693
private final BranchProfile invalidLengthProfile = BranchProfile.create();
9794

98-
@Specialization(guards = {"fd < 0", "isNoValue(access)", "isNoValue(offset)"})
99-
PMMap doAnonymous(LazyPythonClass clazz, int fd, int length, Object tagname, @SuppressWarnings("unused") PNone access, @SuppressWarnings("unused") PNone offset) {
95+
@Specialization(guards = {"isAnonymous(fd)", "isNoValue(access)", "isNoValue(offset)"})
96+
PMMap doAnonymous(LazyPythonClass clazz, @SuppressWarnings("unused") int fd, int length, @SuppressWarnings("unused") Object tagname, @SuppressWarnings("unused") PNone access,
97+
@SuppressWarnings("unused") PNone offset) {
10098
checkLength(length);
101-
return new PMMap(clazz, new AnonymousMap(length), length, 0);
99+
return factory().createMMap(clazz, new AnonymousMap(length), length, 0);
102100
}
103101

104102
@Specialization(guards = {"fd >= 0", "isNoValue(access)", "isNoValue(offset)"})
@@ -108,14 +106,13 @@ PMMap doIt(LazyPythonClass clazz, int fd, int length, Object tagname, @SuppressW
108106

109107
// mmap(fileno, length, tagname=None, access=ACCESS_DEFAULT[, offset])
110108
@Specialization(guards = "fd >= 0")
111-
PMMap doFile(LazyPythonClass clazz, int fd, int length, @SuppressWarnings("unused") Object tagname, int access, long offset) {
109+
PMMap doFile(LazyPythonClass clazz, int fd, int length, @SuppressWarnings("unused") Object tagname, @SuppressWarnings("unused") int access, long offset) {
112110
checkLength(length);
113111

114112
String path = getContext().getResources().getFilePath(fd);
115113
TruffleFile truffleFile = getContext().getEnv().getTruffleFile(path);
116114

117115
// TODO(fa) correctly honor access flags
118-
MapMode mode = convertAccessToMapMode(access);
119116
Set<StandardOpenOption> options = new HashSet<>();
120117
options.add(StandardOpenOption.READ);
121118
options.add(StandardOpenOption.WRITE);
@@ -125,12 +122,27 @@ PMMap doFile(LazyPythonClass clazz, int fd, int length, @SuppressWarnings("unuse
125122
try {
126123
fileChannel = truffleFile.newByteChannel(options);
127124
fileChannel.position(offset);
128-
return new PMMap(PythonBuiltinClassType.PMMap, fileChannel, length, offset);
125+
return factory().createMMap(clazz, fileChannel, length, offset);
129126
} catch (IOException e) {
130127
throw raise(ValueError, "cannot mmap file");
131128
}
132129
}
133130

131+
@Specialization(guards = "isIllegal(fd)")
132+
@SuppressWarnings("unused")
133+
PMMap doAnonymous(LazyPythonClass clazz, int fd, Object length, Object tagname, PNone access, PNone offset) {
134+
throw raise(PythonBuiltinClassType.OSError);
135+
}
136+
137+
protected static boolean isAnonymous(int fd) {
138+
return fd == -1;
139+
}
140+
141+
protected static boolean isIllegal(int fd) {
142+
return fd < -1;
143+
}
144+
145+
@SuppressWarnings("unused")
134146
private MapMode convertAccessToMapMode(int access) {
135147
switch (access) {
136148
case 0:
@@ -189,7 +201,7 @@ public long position() throws IOException {
189201
}
190202

191203
public SeekableByteChannel position(long newPosition) throws IOException {
192-
if (newPosition < 0 || newPosition >= data.length) {
204+
if (newPosition < 0) {
193205
throw new IllegalArgumentException();
194206
}
195207
cur = (int) newPosition;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/mmap/MMapBuiltins.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ int doSingle(VirtualFrame frame, PMMap self, Object idxObj,
246246
long oldPos = channel.position();
247247

248248
channel.position(idx);
249-
int res = readByteNode.execute(channel);
249+
int res = readByteNode.execute(channel) & 0xFF;
250250

251251
// restore position
252252
channel.position(oldPos);
@@ -458,18 +458,25 @@ int readByte(PMMap self,
458458
@TypeSystemReference(PythonArithmeticTypes.class)
459459
abstract static class ReadNode extends PythonBuiltinNode {
460460

461-
@Specialization(guards = "isNoValue(n)")
462-
PBytes read(PMMap self, @SuppressWarnings("unused") PNone n,
461+
@Specialization
462+
PBytes readUnlimited(PMMap self, @SuppressWarnings("unused") PNone n,
463463
@Cached("create()") ReadFromChannelNode readChannelNode) {
464+
// intentionally accept NO_VALUE and NONE; both mean that we read unlimited amount of
465+
// bytes
464466
ByteSequenceStorage res = readChannelNode.execute(self.getChannel(), ReadFromChannelNode.MAX_READ);
465467
return factory().createBytes(res);
466468
}
467469

468470
@Specialization(guards = "!isNoValue(n)")
469471
PBytes read(PMMap self, Object n,
470472
@Cached("create()") ReadFromChannelNode readChannelNode,
471-
@Cached("create()") CastToIndexNode castToIndexNode) {
472-
ByteSequenceStorage res = readChannelNode.execute(self.getChannel(), castToIndexNode.execute(n));
473+
@Cached("create()") CastToIndexNode castToIndexNode,
474+
@Cached("createBinaryProfile()") ConditionProfile negativeProfile) {
475+
int nread = castToIndexNode.execute(n);
476+
if (negativeProfile.profile(nread < 0)) {
477+
return readUnlimited(self, PNone.NO_VALUE, readChannelNode);
478+
}
479+
ByteSequenceStorage res = readChannelNode.execute(self.getChannel(), nread);
473480
return factory().createBytes(res);
474481
}
475482

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/util/ChannelNodes.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ int readByte(ReadableByteChannel channel,
178178
if (readProfile.profile(read != 1)) {
179179
return handleError(channel);
180180
}
181-
return get(buf);
181+
return get(buf) & 0xFF;
182182
}
183183

184184
private int handleError(Channel channel) {

0 commit comments

Comments
 (0)