Skip to content

Commit 98f0ec9

Browse files
committed
Improve placing of TruffleBoundaries.
1 parent 72ea768 commit 98f0ec9

File tree

5 files changed

+72
-60
lines changed

5 files changed

+72
-60
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MMapModuleBuiltins.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
6363
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
6464
import com.oracle.truffle.api.TruffleFile;
65+
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
6566
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
6667
import com.oracle.truffle.api.dsl.NodeFactory;
6768
import com.oracle.truffle.api.dsl.Specialization;
@@ -113,21 +114,28 @@ PMMap doFile(LazyPythonClass clazz, int fd, int length, @SuppressWarnings("unuse
113114
TruffleFile truffleFile = getContext().getEnv().getTruffleFile(path);
114115

115116
// TODO(fa) correctly honor access flags
116-
Set<StandardOpenOption> options = new HashSet<>();
117-
options.add(StandardOpenOption.READ);
118-
options.add(StandardOpenOption.WRITE);
117+
Set<StandardOpenOption> options = set(StandardOpenOption.READ, StandardOpenOption.WRITE);
119118

120119
// we create a new channel otherwise we cannot guarantee that the cursor is exclusive
121120
SeekableByteChannel fileChannel;
122121
try {
123122
fileChannel = truffleFile.newByteChannel(options);
124-
fileChannel.position(offset);
123+
position(fileChannel, offset);
125124
return factory().createMMap(clazz, fileChannel, length, offset);
126125
} catch (IOException e) {
127126
throw raise(ValueError, "cannot mmap file");
128127
}
129128
}
130129

130+
@TruffleBoundary
131+
private static Set<StandardOpenOption> set(StandardOpenOption... options) {
132+
Set<StandardOpenOption> s = new HashSet<>();
133+
for (StandardOpenOption o : options) {
134+
s.add(o);
135+
}
136+
return s;
137+
}
138+
131139
@Specialization(guards = "isIllegal(fd)")
132140
@SuppressWarnings("unused")
133141
PMMap doAnonymous(LazyPythonClass clazz, int fd, Object length, Object tagname, PNone access, PNone offset) {
@@ -164,6 +172,10 @@ private void checkLength(int length) {
164172
}
165173
}
166174

175+
@TruffleBoundary
176+
private static void position(SeekableByteChannel ch, long offset) throws IOException {
177+
ch.position(offset);
178+
}
167179
}
168180

169181
private static class AnonymousMap implements SeekableByteChannel {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/SequenceStorageNodes.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@
128128
import com.oracle.graal.python.runtime.sequence.storage.TypedSequenceStorage;
129129
import com.oracle.truffle.api.CompilerDirectives;
130130
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
131-
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
132131
import com.oracle.truffle.api.dsl.Cached;
133132
import com.oracle.truffle.api.dsl.Fallback;
134133
import com.oracle.truffle.api.dsl.ImportStatic;
@@ -1745,7 +1744,6 @@ byte[] doFallback(@SuppressWarnings("unused") SequenceStorage s) {
17451744
throw raise(TypeError, "expected a bytes-like object");
17461745
}
17471746

1748-
@TruffleBoundary(transferToInterpreterOnException = false)
17491747
private static byte[] exactCopy(byte[] barr, int len) {
17501748
return Arrays.copyOf(barr, len);
17511749
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/mmap/MMapBuiltins.java

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,19 @@
113113
@CoreFunctions(extendClasses = PythonBuiltinClassType.PMMap)
114114
public class MMapBuiltins extends PythonBuiltins {
115115

116-
protected interface ByteReadingNode {
116+
protected interface MMapBaseNode {
117+
@TruffleBoundary
118+
default long position(SeekableByteChannel ch) throws IOException {
119+
return ch.position();
120+
}
121+
122+
@TruffleBoundary
123+
default void position(SeekableByteChannel ch, long offset) throws IOException {
124+
ch.position(offset);
125+
}
126+
}
127+
128+
protected interface ByteReadingNode extends MMapBaseNode {
117129

118130
default ReadByteFromChannelNode createValueError() {
119131
return ReadByteFromChannelNode.create(() -> new ChannelNodes.ReadByteErrorHandler() {
@@ -137,7 +149,7 @@ public int execute(Channel channel) {
137149
}
138150
}
139151

140-
protected interface ByteWritingNode {
152+
protected interface ByteWritingNode extends MMapBaseNode {
141153

142154
default WriteByteToChannelNode createValueError() {
143155
return WriteByteToChannelNode.create(() -> new ChannelNodes.WriteByteErrorHandler() {
@@ -243,13 +255,13 @@ int doSingle(VirtualFrame frame, PMMap self, Object idxObj,
243255
long idx = i < 0 ? i + len : i;
244256

245257
// save current position
246-
long oldPos = channel.position();
258+
long oldPos = position(channel);
247259

248-
channel.position(idx);
260+
position(channel, idx);
249261
int res = readByteNode.execute(channel) & 0xFF;
250262

251263
// restore position
252-
channel.position(oldPos);
264+
position(channel, oldPos);
253265

254266
return res;
255267

@@ -268,13 +280,13 @@ Object doSlice(VirtualFrame frame, PMMap self, PSlice idx,
268280
SeekableByteChannel channel = self.getChannel();
269281

270282
// save current position
271-
long oldPos = channel.position();
283+
long oldPos = position(channel);
272284

273-
channel.position(info.start);
285+
position(channel, info.start);
274286
ByteSequenceStorage s = readNode.execute(channel, info.length);
275287

276288
// restore position
277-
channel.position(oldPos);
289+
position(channel, oldPos);
278290

279291
return factory().createBytes(s);
280292
} catch (IOException e) {
@@ -307,13 +319,13 @@ PNone doSingle(VirtualFrame frame, PMMap self, Object idxObj, Object val,
307319
}
308320

309321
// save current position
310-
long oldPos = channel.position();
322+
long oldPos = position(channel);
311323

312-
channel.position(idx);
324+
position(channel, idx);
313325
writeByteNode.execute(channel, castToByteNode.execute(val));
314326

315327
// restore position
316-
channel.position(oldPos);
328+
position(channel, oldPos);
317329

318330
return PNone.NONE;
319331

@@ -339,13 +351,13 @@ PNone doSlice(VirtualFrame frame, PMMap self, PSlice idx, PIBytesLike val,
339351
}
340352

341353
// save current position
342-
long oldPos = channel.position();
354+
long oldPos = position(channel);
343355

344-
channel.position(info.start);
356+
position(channel, info.start);
345357
writeNode.execute(channel, getStorageNode.execute(val), info.length);
346358

347359
// restore position
348-
channel.position(oldPos);
360+
position(channel, oldPos);
349361

350362
return PNone.NONE;
351363

@@ -428,13 +440,13 @@ long size(VirtualFrame frame, PMMap self,
428440

429441
@Builtin(name = "tell", fixedNumOfPositionalArgs = 1)
430442
@GenerateNodeFactory
431-
abstract static class TellNode extends PythonBuiltinNode {
443+
abstract static class TellNode extends PythonBuiltinNode implements ByteReadingNode {
432444
@Specialization
433445
long readline(VirtualFrame frame, PMMap self) {
434446

435447
try {
436448
SeekableByteChannel channel = self.getChannel();
437-
return channel.position() - self.getOffset();
449+
return position(channel) - self.getOffset();
438450
} catch (IOException e) {
439451
throw raiseOSError(frame, OSErrorEnum.EIO, e.getMessage());
440452
}
@@ -484,7 +496,7 @@ PBytes read(PMMap self, Object n,
484496

485497
@Builtin(name = "readline", fixedNumOfPositionalArgs = 1)
486498
@GenerateNodeFactory
487-
abstract static class ReadlineNode extends PythonBuiltinNode {
499+
abstract static class ReadlineNode extends PythonUnaryBuiltinNode implements ByteReadingNode {
488500

489501
@Specialization
490502
Object readline(PMMap self,
@@ -504,7 +516,7 @@ Object readline(PMMap self,
504516
appendNode.execute(res, b);
505517
} else {
506518
// recover correct position (i.e. number of remaining bytes in buffer)
507-
channel.position(channel.position() - buf.remaining() - 1);
519+
position(channel, position(channel) - buf.remaining() - 1);
508520
break outer;
509521
}
510522
}
@@ -550,7 +562,7 @@ int writeMemoryview(PMMap self, PMemoryView memoryView,
550562
@Builtin(name = "seek", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 3)
551563
@GenerateNodeFactory
552564
@TypeSystemReference(PythonArithmeticTypes.class)
553-
abstract static class SeekNode extends PythonBuiltinNode {
565+
abstract static class SeekNode extends PythonBuiltinNode implements MMapBaseNode {
554566
@Child private CastToIndexNode castToLongNode;
555567

556568
private final BranchProfile errorProfile = BranchProfile.create();
@@ -590,24 +602,14 @@ Object seek(VirtualFrame frame, PMMap self, long dist, Object how) {
590602
errorProfile.enter();
591603
throw raise(PythonBuiltinClassType.ValueError, "seek out of range");
592604
}
593-
doSeek(channel, where);
605+
position(channel, where);
594606
return PNone.NONE;
595607
} catch (IOException e) {
596608
errorProfile.enter();
597609
throw raiseOSError(frame, OSErrorEnum.EIO, e.getMessage());
598610
}
599611
}
600612

601-
@TruffleBoundary(allowInlining = true)
602-
private static long position(SeekableByteChannel channel) throws IOException {
603-
return channel.position();
604-
}
605-
606-
@TruffleBoundary(allowInlining = true)
607-
private static void doSeek(SeekableByteChannel channel, long where) throws IOException {
608-
channel.position(where);
609-
}
610-
611613
private int castToInt(Object val) {
612614
if (castToLongNode == null) {
613615
CompilerDirectives.transferToInterpreterAndInvalidate();
@@ -653,7 +655,7 @@ long find(PMMap primary, PIBytesLike sub, Object starting, Object ending,
653655
// TODO implement a more efficient algorithm
654656
outer: for (long i = start; i < end; i++) {
655657
// TODO(fa) don't seek but use circular buffer
656-
channel.position(i);
658+
position(channel, i);
657659
for (int j = 0; j < len2; j++) {
658660
int hb = readByteNode.execute(channel);
659661
int nb = getGetRightItemNode().executeInt(needle, j);
@@ -682,7 +684,7 @@ long find(PMMap primary, int sub, Object starting, @SuppressWarnings("unused") O
682684
long start = s < 0 ? s + len1 : s;
683685
long end = Math.max(e < 0 ? e + len1 : e, len1);
684686

685-
channel.position(start);
687+
position(channel, start);
686688

687689
for (long i = start; i < end; i++) {
688690
int hb = readByteNode.execute(channel);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/util/ChannelNodes.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,64 +78,64 @@ public abstract static class WriteByteErrorHandler extends PNodeWithContext {
7878
public abstract void execute(Channel channel, byte b);
7979
}
8080

81-
abstract static class ReadFromChannelBaseNode extends PNodeWithContext {
82-
83-
private final BranchProfile gotException = BranchProfile.create();
84-
81+
protected interface ChannelBaseNode {
8582
@TruffleBoundary(allowInlining = true)
86-
protected static ByteBuffer allocate(int n) {
83+
default ByteBuffer allocate(int n) {
8784
return ByteBuffer.allocate(n);
8885
}
8986

9087
@TruffleBoundary(transferToInterpreterOnException = false)
91-
protected static long availableSize(SeekableByteChannel channel) throws IOException {
88+
default long availableSize(SeekableByteChannel channel) throws IOException {
9289
return channel.size() - channel.position();
9390
}
91+
}
92+
93+
abstract static class ReadFromChannelBaseNode extends PNodeWithContext implements ChannelBaseNode {
94+
95+
private final BranchProfile gotException = BranchProfile.create();
9496

9597
@TruffleBoundary(allowInlining = true)
9698
protected static byte[] getByteBufferArray(ByteBuffer dst) {
9799
return dst.array();
98100
}
99101

100-
@TruffleBoundary(allowInlining = true)
101102
protected int readIntoBuffer(ReadableByteChannel readableChannel, ByteBuffer dst) {
102103
try {
103-
return readableChannel.read(dst);
104+
return read(readableChannel, dst);
104105
} catch (IOException e) {
105106
gotException.enter();
106107
throw raise(OSError, e);
107108
}
108109
}
110+
111+
@TruffleBoundary(allowInlining = true, transferToInterpreterOnException = false)
112+
private static int read(ReadableByteChannel readableChannel, ByteBuffer dst) throws IOException {
113+
return readableChannel.read(dst);
114+
}
109115
}
110116

111-
abstract static class WriteToChannelBaseNode extends PNodeWithContext {
117+
abstract static class WriteToChannelBaseNode extends PNodeWithContext implements ChannelBaseNode {
112118

113119
private final BranchProfile gotException = BranchProfile.create();
114120

115-
@TruffleBoundary(allowInlining = true)
116-
protected static ByteBuffer allocate(int n) {
117-
return ByteBuffer.allocate(n);
118-
}
119-
120-
@TruffleBoundary(transferToInterpreterOnException = false)
121-
protected static long availableSize(SeekableByteChannel channel) throws IOException {
122-
return channel.size() - channel.position();
123-
}
124-
125121
@TruffleBoundary(allowInlining = true)
126122
protected static byte[] getByteBufferArray(ByteBuffer dst) {
127123
return dst.array();
128124
}
129125

130-
@TruffleBoundary(allowInlining = true)
131126
protected int writeFromBuffer(WritableByteChannel writableChannel, ByteBuffer src) {
132127
try {
133-
return writableChannel.write(src);
128+
return write(writableChannel, src);
134129
} catch (IOException e) {
135130
gotException.enter();
136131
throw raise(OSError, e);
137132
}
138133
}
134+
135+
@TruffleBoundary(allowInlining = true, transferToInterpreterOnException = false)
136+
private static int write(WritableByteChannel writableChannel, ByteBuffer src) throws IOException {
137+
return writableChannel.write(src);
138+
}
139139
}
140140

141141
public abstract static class ReadFromChannelNode extends ReadFromChannelBaseNode {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/PosixResources.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public Channel getFileChannel(int fd) {
105105
return null;
106106
}
107107

108-
@TruffleBoundary(allowInlining = true)
108+
@TruffleBoundary
109109
public String getFilePath(int fd) {
110110
if (filePaths.size() > fd) {
111111
return filePaths.get(fd);

0 commit comments

Comments
 (0)