Skip to content

Commit c1059e8

Browse files
committed
refactor zlib compressobj, fix raw stream compression
1 parent 5697731 commit c1059e8

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ZLibModuleBuiltins.java

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@
4141

4242
package com.oracle.graal.python.builtins.modules;
4343

44+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
45+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ZLibError;
46+
47+
import java.io.ByteArrayOutputStream;
48+
import java.util.List;
49+
import java.util.zip.Adler32;
50+
import java.util.zip.CRC32;
51+
import java.util.zip.DataFormatException;
52+
import java.util.zip.Deflater;
53+
import java.util.zip.Inflater;
54+
4455
import com.oracle.graal.python.builtins.Builtin;
4556
import com.oracle.graal.python.builtins.CoreFunctions;
4657
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
@@ -56,12 +67,9 @@
5667
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
5768
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
5869
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
59-
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
6070
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
6171
import com.oracle.graal.python.nodes.util.CastToIntegerFromIntNode;
6272
import com.oracle.graal.python.runtime.PythonCore;
63-
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
64-
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ZLibError;
6573
import com.oracle.truffle.api.CompilerDirectives;
6674
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
6775
import com.oracle.truffle.api.dsl.Cached;
@@ -73,14 +81,6 @@
7381
import com.oracle.truffle.api.interop.ForeignAccess;
7482
import com.oracle.truffle.api.interop.TruffleObject;
7583
import com.oracle.truffle.api.profiles.ConditionProfile;
76-
import java.io.ByteArrayOutputStream;
77-
import java.util.Arrays;
78-
import java.util.List;
79-
import java.util.zip.Adler32;
80-
import java.util.zip.CRC32;
81-
import java.util.zip.DataFormatException;
82-
import java.util.zip.Deflater;
83-
import java.util.zip.Inflater;
8484

8585
@CoreFunctions(defineModule = ZLibModuleBuiltins.MODULE_NAME)
8686
public class ZLibModuleBuiltins extends PythonBuiltins {
@@ -386,10 +386,20 @@ abstract static class DeflateInitNode extends PythonBuiltinNode {
386386
@Specialization
387387
@TruffleBoundary
388388
Object deflateInit(int level, int method, int wbits, int memLevel, int strategy, Object zdict) {
389+
Deflater deflater;
390+
if (wbits < 0) {
391+
deflater = new Deflater(level, true);
392+
// generate a RAW stream
393+
} else if (wbits >= 25) {
394+
// include gzip container
395+
throw raise(PythonBuiltinClassType.NotImplementedError, "gzip containers");
396+
} else {
397+
deflater = new Deflater(level, true);
398+
}
399+
389400
if (method != DEFLATED) {
390401
throw raise(PythonBuiltinClassType.ValueError, "only DEFLATED (%d) allowed as method, got %d", DEFLATED, method);
391402
}
392-
Deflater deflater = new Deflater(level);
393403
deflater.setStrategy(strategy);
394404
if (zdict instanceof String) {
395405
deflater.setDictionary(((String) zdict).getBytes());
@@ -420,26 +430,28 @@ abstract static class DeflateCompress extends PythonTernaryBuiltinNode {
420430
@Specialization
421431
@TruffleBoundary
422432
Object deflateCompress(DeflaterWrapper stream, PIBytesLike pb, int mode) {
433+
ByteArrayOutputStream baos = new ByteArrayOutputStream();
423434
byte[] data = toBytes.execute(pb);
435+
byte[] result = new byte[DEF_BUF_SIZE];
436+
424437
stream.deflater.setInput(data);
425-
byte[] result = new byte[data.length];
426-
int bytesWritten = stream.deflater.deflate(result, 0, result.length, mode);
427-
while (bytesWritten > 0 && bytesWritten >= result.length) {
428-
result = Arrays.copyOf(result, bytesWritten * 2);
429-
bytesWritten += stream.deflater.deflate(result, bytesWritten, result.length - bytesWritten, mode);
438+
int deflateMode = mode;
439+
if (mode == Z_FINISH) {
440+
deflateMode = Z_SYNC_FLUSH;
441+
stream.deflater.finish();
430442
}
431-
return factory().createBytes(result);
432-
}
433-
}
434443

435-
@Builtin(name = "zlib_deflateEnd", fixedNumOfPositionalArgs = 1)
436-
@GenerateNodeFactory
437-
abstract static class DeflateEnd extends PythonUnaryBuiltinNode {
438-
@Specialization
439-
@TruffleBoundary
440-
PNone deflateEnd(DeflaterWrapper stream) {
441-
stream.deflater.end();
442-
return PNone.NONE;
444+
int bytesWritten = result.length;
445+
while (bytesWritten == result.length) {
446+
bytesWritten = stream.deflater.deflate(result, 0, result.length, deflateMode);
447+
baos.write(result, 0, bytesWritten);
448+
}
449+
450+
if (mode == Z_FINISH) {
451+
stream.deflater.end();
452+
}
453+
454+
return factory().createBytes(baos.toByteArray());
443455
}
444456
}
445457

graalpython/lib-graalpython/zlib.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,9 @@ def flush(self, mode=Z_FINISH):
9797
used after calling the flush() method. Otherwise, more data
9898
can still be compressed.
9999
"""
100-
if mode == Z_FINISH:
101-
finish = True
102-
mode = Z_FULL_FLUSH
103100
if self.stream:
104101
result = zlib_deflateCompress(self.stream, b"", mode)
105-
if finish:
106-
zlib_deflateEnd(self.stream)
102+
if mode == Z_FINISH:
107103
self.stream = None
108104
return result
109105
else:

0 commit comments

Comments
 (0)