68
68
import com .oracle .graal .python .builtins .PythonBuiltins ;
69
69
import com .oracle .graal .python .builtins .objects .PNone ;
70
70
import com .oracle .graal .python .builtins .objects .bytes .PBytes ;
71
+ import com .oracle .graal .python .builtins .objects .bytes .PIBytesLike ;
71
72
import com .oracle .graal .python .builtins .objects .common .SequenceStorageNodes ;
72
73
import com .oracle .graal .python .builtins .objects .common .SequenceStorageNodes .NoGeneralizationNode ;
74
+ import com .oracle .graal .python .builtins .objects .common .SequenceStorageNodes .NormalizeIndexNode ;
73
75
import com .oracle .graal .python .builtins .objects .exception .OSErrorEnum ;
76
+ import com .oracle .graal .python .builtins .objects .ints .PInt ;
77
+ import com .oracle .graal .python .builtins .objects .mmap .MMapBuiltinsFactory .InternalLenNodeGen ;
78
+ import com .oracle .graal .python .builtins .objects .slice .PSlice ;
79
+ import com .oracle .graal .python .builtins .objects .slice .PSlice .SliceInfo ;
80
+ import com .oracle .graal .python .nodes .PNodeWithContext ;
74
81
import com .oracle .graal .python .nodes .SpecialMethodNames ;
75
82
import com .oracle .graal .python .nodes .call .special .LookupAndCallUnaryNode ;
76
83
import com .oracle .graal .python .nodes .function .PythonBuiltinBaseNode ;
81
88
import com .oracle .graal .python .nodes .truffle .PythonArithmeticTypes ;
82
89
import com .oracle .graal .python .nodes .util .CastToByteNode ;
83
90
import com .oracle .graal .python .nodes .util .CastToIndexNode ;
91
+ import com .oracle .graal .python .nodes .util .CastToJavaLongNode ;
84
92
import com .oracle .graal .python .nodes .util .ChannelNodes .ReadByteFromChannelNode ;
85
93
import com .oracle .graal .python .nodes .util .ChannelNodes .ReadFromChannelNode ;
86
94
import com .oracle .graal .python .runtime .sequence .storage .ByteSequenceStorage ;
95
+ import com .oracle .graal .python .runtime .sequence .storage .SequenceStorage ;
96
+ import com .oracle .truffle .api .CompilerDirectives ;
87
97
import com .oracle .truffle .api .CompilerDirectives .TruffleBoundary ;
88
98
import com .oracle .truffle .api .dsl .Cached ;
89
99
import com .oracle .truffle .api .dsl .GenerateNodeFactory ;
90
100
import com .oracle .truffle .api .dsl .NodeFactory ;
91
101
import com .oracle .truffle .api .dsl .Specialization ;
92
102
import com .oracle .truffle .api .dsl .TypeSystemReference ;
93
103
import com .oracle .truffle .api .frame .VirtualFrame ;
94
- import com .oracle .truffle .api .profiles .ConditionProfile ;
104
+ import com .oracle .truffle .api .profiles .BranchProfile ;
95
105
96
106
@ CoreFunctions (extendClasses = PythonBuiltinClassType .PMMap )
97
107
public class MMapBuiltins extends PythonBuiltins {
@@ -163,7 +173,59 @@ abstract static class ReprNode extends StrNode {
163
173
164
174
@ Builtin (name = __GETITEM__ , fixedNumOfPositionalArgs = 2 )
165
175
@ GenerateNodeFactory
166
- abstract static class GetItemNode extends PythonBinaryBuiltinNode {
176
+ abstract static class GetItemNode extends PythonBuiltinNode {
177
+
178
+ @ Specialization (guards = "!isPSlice(idxObj)" )
179
+ int doSingle (VirtualFrame frame , PMMap self , Object idxObj ,
180
+ @ Cached ("create()" ) ReadByteFromChannelNode readByteNode ,
181
+ @ Cached ("createExact()" ) CastToJavaLongNode castToLongNode ,
182
+ @ Cached ("create()" ) InternalLenNode lenNode ) {
183
+
184
+ try {
185
+ long i = castToLongNode .execute (idxObj );
186
+ long len = lenNode .execute (frame , self );
187
+ SeekableByteChannel channel = self .getChannel ();
188
+ long idx = i < 0 ? i + len : i ;
189
+
190
+ // save current position
191
+ long oldPos = channel .position ();
192
+
193
+ channel .position (idx );
194
+ int res = readByteNode .execute (channel );
195
+
196
+ // restore position
197
+ channel .position (oldPos );
198
+
199
+ return res ;
200
+
201
+ } catch (IOException e ) {
202
+ throw raise (PythonBuiltinClassType .OSError , e .getMessage ());
203
+ }
204
+ }
205
+
206
+ @ Specialization
207
+ Object doSlice (VirtualFrame frame , PMMap self , PSlice idx ,
208
+ @ Cached ("create()" ) ReadFromChannelNode readNode ,
209
+ @ Cached ("create()" ) InternalLenNode lenNode ) {
210
+ try {
211
+ long len = lenNode .execute (frame , self );
212
+ SliceInfo info = idx .computeIndices (PInt .intValueExact (len ));
213
+ SeekableByteChannel channel = self .getChannel ();
214
+
215
+ // save current position
216
+ long oldPos = channel .position ();
217
+
218
+ channel .position (info .start );
219
+ ByteSequenceStorage s = readNode .execute (channel , info .length );
220
+
221
+ // restore position
222
+ channel .position (oldPos );
223
+
224
+ return factory ().createBytes (s );
225
+ } catch (IOException e ) {
226
+ throw raise (PythonBuiltinClassType .OSError , e .getMessage ());
227
+ }
228
+ }
167
229
}
168
230
169
231
@ Builtin (name = SpecialMethodNames .__SETITEM__ , fixedNumOfPositionalArgs = 3 )
@@ -176,12 +238,9 @@ abstract static class SetItemNode extends PythonTernaryBuiltinNode {
176
238
@ GenerateNodeFactory
177
239
public abstract static class LenNode extends PythonBuiltinNode {
178
240
@ Specialization
179
- long len (VirtualFrame frame , PMMap self ) {
180
- try {
181
- return self .getChannel ().size ();
182
- } catch (IOException e ) {
183
- throw raiseOSError (frame , OSErrorEnum .EIO , e .getMessage ());
184
- }
241
+ long len (VirtualFrame frame , PMMap self ,
242
+ @ Cached ("create()" ) InternalLenNode lenNode ) {
243
+ return lenNode .execute (frame , self );
185
244
}
186
245
}
187
246
@@ -227,15 +286,8 @@ abstract static class SizeNode extends PythonBuiltinNode {
227
286
228
287
@ Specialization
229
288
long size (VirtualFrame frame , PMMap self ,
230
- @ Cached ("createBinaryProfile()" ) ConditionProfile profile ) {
231
- if (profile .profile (self .getLength () == 0 )) {
232
- try {
233
- return self .getChannel ().size () - self .getOffset ();
234
- } catch (IOException e ) {
235
- throw raiseOSError (frame , OSErrorEnum .EIO , e .getMessage ());
236
- }
237
- }
238
- return self .getLength ();
289
+ @ Cached ("create()" ) InternalLenNode lenNode ) {
290
+ return lenNode .execute (frame , self );
239
291
}
240
292
}
241
293
@@ -378,4 +430,143 @@ private Object doSeek(PMMap self, long dist, int how) throws IOException {
378
430
}
379
431
}
380
432
433
+ @ Builtin (name = "find" , minNumOfPositionalArgs = 2 , maxNumOfPositionalArgs = 4 )
434
+ @ GenerateNodeFactory
435
+ @ TypeSystemReference (PythonArithmeticTypes .class )
436
+ public abstract static class FindNode extends PythonBuiltinNode {
437
+
438
+ @ Child private NormalizeIndexNode normalizeIndexNode ;
439
+ @ Child private SequenceStorageNodes .GetItemNode getLeftItemNode ;
440
+ @ Child private SequenceStorageNodes .GetItemNode getRightItemNode ;
441
+
442
+ public abstract long execute (PMMap bytes , Object sub , Object starting , Object ending );
443
+
444
+ @ Specialization
445
+ long find (PMMap primary , PIBytesLike sub , Object starting , Object ending ,
446
+ @ Cached ("create()" ) ReadByteFromChannelNode readByteNode ) {
447
+ try {
448
+ SeekableByteChannel channel = primary .getChannel ();
449
+ long len1 = channel .size ();
450
+
451
+ SequenceStorage needle = sub .getSequenceStorage ();
452
+ int len2 = needle .length ();
453
+
454
+ long s = castToLong (starting , 0 );
455
+ long e = castToLong (ending , len1 );
456
+
457
+ long start = s < 0 ? s + len1 : s ;
458
+ long end = e < 0 ? e + len1 : e ;
459
+
460
+ if (start >= len1 || len1 < len2 ) {
461
+ return -1 ;
462
+ } else if (end > len1 ) {
463
+ end = len1 ;
464
+ }
465
+
466
+ // TODO implement a more efficient algorithm
467
+ outer : for (long i = start ; i < end ; i ++) {
468
+ // TODO(fa) don't seek but use circular buffer
469
+ channel .position (i );
470
+ for (int j = 0 ; j < len2 ; j ++) {
471
+ int hb = readByteNode .execute (channel );
472
+ int nb = getGetRightItemNode ().executeInt (needle , j );
473
+ if (nb != hb || i + j >= end ) {
474
+ continue outer ;
475
+ }
476
+ }
477
+ return i ;
478
+ }
479
+ return -1 ;
480
+ } catch (IOException e ) {
481
+ throw raise (PythonBuiltinClassType .OSError , e .getMessage ());
482
+ }
483
+ }
484
+
485
+ @ Specialization
486
+ long find (PMMap primary , int sub , Object starting , @ SuppressWarnings ("unused" ) Object ending ,
487
+ @ Cached ("create()" ) ReadByteFromChannelNode readByteNode ) {
488
+ try {
489
+ SeekableByteChannel channel = primary .getChannel ();
490
+ long len1 = channel .size ();
491
+
492
+ long s = castToLong (starting , 0 );
493
+ long e = castToLong (ending , len1 );
494
+
495
+ long start = s < 0 ? s + len1 : s ;
496
+ long end = Math .max (e < 0 ? e + len1 : e , len1 );
497
+
498
+ channel .position (start );
499
+
500
+ for (long i = start ; i < end ; i ++) {
501
+ int hb = readByteNode .execute (channel );
502
+ if (hb == sub ) {
503
+ return i ;
504
+ }
505
+ }
506
+ return -1 ;
507
+ } catch (IOException e ) {
508
+ throw raise (PythonBuiltinClassType .OSError , e .getMessage ());
509
+ }
510
+ }
511
+
512
+ // TODO(fa): use node
513
+ private static long castToLong (Object obj , long defaultVal ) {
514
+ if (obj instanceof Integer || obj instanceof Long ) {
515
+ return ((Number ) obj ).longValue ();
516
+ } else if (obj instanceof PInt ) {
517
+ try {
518
+ return ((PInt ) obj ).longValueExact ();
519
+ } catch (ArithmeticException e ) {
520
+ return defaultVal ;
521
+ }
522
+ }
523
+ return defaultVal ;
524
+ }
525
+
526
+ private SequenceStorageNodes .GetItemNode getGetRightItemNode () {
527
+ if (getRightItemNode == null ) {
528
+ CompilerDirectives .transferToInterpreterAndInvalidate ();
529
+ getRightItemNode = insert (SequenceStorageNodes .GetItemNode .create ());
530
+ }
531
+ return getRightItemNode ;
532
+ }
533
+ }
534
+
535
+ abstract static class InternalLenNode extends PNodeWithContext {
536
+
537
+ public abstract long execute (VirtualFrame frame , PMMap self );
538
+
539
+ @ Specialization (guards = "self.getLength() == 0" )
540
+ long doFull (VirtualFrame frame , PMMap self ,
541
+ @ Cached ("create()" ) BranchProfile profile ) {
542
+ try {
543
+ return self .getChannel ().size () - self .getOffset ();
544
+ } catch (IOException e ) {
545
+ profile .enter ();
546
+ throw raiseOSError (frame , OSErrorEnum .EIO , e .getMessage ());
547
+ }
548
+ }
549
+
550
+ @ Specialization (guards = "self.getLength() > 0" )
551
+ long doWindow (@ SuppressWarnings ("unused" ) VirtualFrame frame , PMMap self ) {
552
+ return self .getLength ();
553
+ }
554
+
555
+ @ Specialization
556
+ long doGeneric (VirtualFrame frame , PMMap self ) {
557
+ if (self .getLength () == 0 ) {
558
+ try {
559
+ return self .getChannel ().size () - self .getOffset ();
560
+ } catch (IOException e ) {
561
+ throw raiseOSError (frame , OSErrorEnum .EIO , e .getMessage ());
562
+ }
563
+ }
564
+ return self .getLength ();
565
+ }
566
+
567
+ public static InternalLenNode create () {
568
+ return InternalLenNodeGen .create ();
569
+ }
570
+ }
571
+
381
572
}
0 commit comments