diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java index 1b9afb41b68..22dbe21c187 100644 --- a/src/main/java/org/apache/sysds/parser/DataExpression.java +++ b/src/main/java/org/apache/sysds/parser/DataExpression.java @@ -1019,6 +1019,18 @@ public void validateExpression(HashMap ids, HashMap 0 && length <= Integer.MAX_VALUE) { + return new BufferedH5ByteReader(base, length, HDF5_READ_BUFFER_BYTES); + } + return base; + } + else if(is instanceof FileInputStream) { + FileChannel channel = ((FileInputStream) is).getChannel(); + length = channel.size(); + System.out.println("[HDF5] Using FileChannel-backed reader for " + sourceId + " (size=" + length + ")"); + if(HDF5_READ_USE_MMAP && length > 0) { + return new MappedH5ByteReader(channel, length, HDF5_READ_MAP_BYTES); + } + H5ByteReader base = new FileChannelByteReader(channel); + if(length > 0 && length <= Integer.MAX_VALUE) { + return new BufferedH5ByteReader(base, length, HDF5_READ_BUFFER_BYTES); + } + return base; + } + else { + byte[] cached = drainToByteArray(is); + System.out.println("[HDF5] Cached " + cached.length + " bytes into memory for " + sourceId); + return new BufferedH5ByteReader(new ByteArrayH5ByteReader(cached), cached.length, HDF5_READ_BUFFER_BYTES); + } + } + + private static byte[] drainToByteArray(InputStream is) throws IOException { + try(InputStream input = is; ByteArrayOutputStream bos = new ByteArrayOutputStream()) { + byte[] buff = new byte[8192]; + int len; + while((len = input.read(buff)) != -1) + bos.write(buff, 0, len); + return bos.toByteArray(); + } + } + private static MatrixBlock readHDF5MatrixFromHDFS(Path path, JobConf job, FileSystem fs, MatrixBlock dest, long rlen, long clen, int blen, String datasetName) throws IOException, DMLRuntimeException @@ -116,9 +202,8 @@ private static MatrixBlock readHDF5MatrixFromHDFS(Path path, JobConf job, //actual read of individual files long lnnz = 0; for(int fileNo = 0; fileNo < files.size(); fileNo++) { - BufferedInputStream bis = new BufferedInputStream(fs.open(files.get(fileNo)), - (int) (H5Constants.STATIC_HEADER_SIZE + (clen * rlen * 8))); - lnnz += readMatrixFromHDF5(bis, datasetName, dest, 0, rlen, clen, blen); + H5ByteReader byteReader = createByteReader(files.get(fileNo), fs); + lnnz += readMatrixFromHDF5(byteReader, datasetName, dest, 0, rlen, clen, blen); } //post processing dest.setNonZeros(lnnz); @@ -126,45 +211,155 @@ private static MatrixBlock readHDF5MatrixFromHDFS(Path path, JobConf job, return dest; } - public static long readMatrixFromHDF5(BufferedInputStream bis, String datasetName, MatrixBlock dest, + public static long readMatrixFromHDF5(H5ByteReader byteReader, String datasetName, MatrixBlock dest, int rl, long ru, long clen, int blen) { - bis.mark(0); long lnnz = 0; - H5RootObject rootObject = H5.H5Fopen(bis); + boolean skipNnz = HDF5_SKIP_NNZ && !dest.isInSparseFormat(); + if(HDF5_FORCE_DENSE && dest.isInSparseFormat()) { + dest.allocateDenseBlock(true); + skipNnz = HDF5_SKIP_NNZ; + if(HDF5_READ_TRACE) + System.out.println("[HDF5] Forcing dense output for dataset=" + datasetName); + } + H5RootObject rootObject = H5.H5Fopen(byteReader); H5ContiguousDataset contiguousDataset = H5.H5Dopen(rootObject, datasetName); - int[] dims = rootObject.getDimensions(); - int ncol = dims[1]; + int ncol = (int) rootObject.getCol(); + System.out.println("[HDF5] readMatrix dataset=" + datasetName + " dims=" + rootObject.getRow() + "x" + + rootObject.getCol() + " loop=[" + rl + "," + ru + 
") dest=" + dest.getNumRows() + "x" + + dest.getNumColumns()); try { - double[] row = new double[ncol]; + double[] row = null; + double[] blockBuffer = null; + int[] ixBuffer = null; + double[] valBuffer = null; + long elemSize = contiguousDataset.getDataType().getDoubleDataType().getSize(); + long rowBytes = (long) ncol * elemSize; + if(rowBytes > Integer.MAX_VALUE) { + throw new DMLRuntimeException("HDF5 row size exceeds buffer capacity: " + rowBytes); + } + int blockRows = 1; + if(!contiguousDataset.isRankGt2() && rowBytes > 0) { + blockRows = (int) Math.max(1, HDF5_READ_BLOCK_BYTES / rowBytes); + } if( dest.isInSparseFormat() ) { SparseBlock sb = dest.getSparseBlock(); - for(int i = rl; i < ru; i++) { - H5.H5Dread(contiguousDataset, i, row); - int lnnzi = UtilFunctions.computeNnz(row, 0, (int)clen); - sb.allocate(i, lnnzi); //avoid row reallocations - for(int j = 0; j < ncol; j++) - sb.append(i, j, row[j]); //prunes zeros - lnnz += lnnzi; + if(contiguousDataset.isRankGt2()) { + row = new double[ncol]; + for(int i = rl; i < ru; i++) { + contiguousDataset.readRowDoubles(i, row, 0); + int lnnzi = UtilFunctions.computeNnz(row, 0, ncol); + sb.allocate(i, lnnzi); //avoid row reallocations + for(int j = 0; j < ncol; j++) + sb.append(i, j, row[j]); //prunes zeros + lnnz += lnnzi; + } + } + else { + ixBuffer = new int[ncol]; + valBuffer = new double[ncol]; + for(int i = rl; i < ru; ) { + int rowsToRead = (int) Math.min(blockRows, ru - i); + ByteBuffer buffer = contiguousDataset.getDataBuffer(i, rowsToRead); + DoubleBuffer db = buffer.order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer(); + int blockSize = rowsToRead * ncol; + if(blockBuffer == null || blockBuffer.length < blockSize) { + blockBuffer = new double[blockSize]; + } + db.get(blockBuffer, 0, blockSize); + for(int r = 0; r < rowsToRead; r++, i++) { + int base = r * ncol; + int lnnzi = 0; + for(int j = 0; j < ncol; j++) { + double v = blockBuffer[base + j]; + if(v != 0) { + ixBuffer[lnnzi] = j; + valBuffer[lnnzi] = v; + lnnzi++; + } + } + sb.allocate(i, lnnzi); //avoid row reallocations + for(int k = 0; k < lnnzi; k++) { + sb.append(i, ixBuffer[k], valBuffer[k]); + } + lnnz += lnnzi; + } + } } } else { DenseBlock denseBlock = dest.getDenseBlock(); - for(int i = rl; i < ru; i++) { - H5.H5Dread(contiguousDataset, i, row); - for(int j = 0; j < ncol; j++) { - if(row[j] != 0) { - denseBlock.set(i, j, row[j]); - lnnz++; + boolean fastDense = denseBlock.isNumeric(ValueType.FP64) + && !(denseBlock instanceof DenseBlockFP64DEDUP) + && !(denseBlock instanceof DenseBlockLFP64DEDUP); + if(contiguousDataset.isRankGt2()) { + row = new double[ncol]; + for(int i = rl; i < ru; i++) { + if(fastDense) { + double[] destRow = denseBlock.values(i); + int destPos = denseBlock.pos(i); + contiguousDataset.readRowDoubles(i, destRow, destPos); + if(!skipNnz) + lnnz += UtilFunctions.computeNnz(destRow, destPos, ncol); + } + else { + contiguousDataset.readRowDoubles(i, row, 0); + denseBlock.set(i, row); + if(!skipNnz) + lnnz += UtilFunctions.computeNnz(row, 0, ncol); + } + } + } + else { + boolean contiguousDense = fastDense && denseBlock.isContiguous(); + double[] destAll = contiguousDense ? 
denseBlock.values(0) : null; + for(int i = rl; i < ru; ) { + int rowsToRead = (int) Math.min(blockRows, ru - i); + ByteBuffer buffer = contiguousDataset.getDataBuffer(i, rowsToRead); + DoubleBuffer db = buffer.order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer(); + int blockSize = rowsToRead * ncol; + if(contiguousDense) { + int destPos = denseBlock.pos(i); + db.get(destAll, destPos, blockSize); + if(!skipNnz) + lnnz += UtilFunctions.computeNnz(destAll, destPos, blockSize); + i += rowsToRead; + continue; + } + if(fastDense) { + if(blockBuffer == null || blockBuffer.length < blockSize) { + blockBuffer = new double[blockSize]; + } + db.get(blockBuffer, 0, blockSize); + for(int r = 0; r < rowsToRead; r++, i++) { + double[] destRow = denseBlock.values(i); + int destPos = denseBlock.pos(i); + System.arraycopy(blockBuffer, r * ncol, destRow, destPos, ncol); + } + if(!skipNnz) + lnnz += UtilFunctions.computeNnz(blockBuffer, 0, blockSize); + continue; + } + for(int r = 0; r < rowsToRead; r++, i++) { + if(row == null) { + row = new double[ncol]; + } + db.get(row, 0, ncol); + denseBlock.set(i, row); + if(!skipNnz) + lnnz += UtilFunctions.computeNnz(row, 0, ncol); } } } } } finally { - IOUtilFunctions.closeSilently(bis); + rootObject.close(); + } + if(skipNnz) { + lnnz = Math.multiplyExact(ru - rl, clen); } return lnnz; } @@ -175,17 +370,287 @@ public static MatrixBlock computeHDF5Size(List files, FileSystem fs, Strin int nrow = 0; int ncol = 0; for(int fileNo = 0; fileNo < files.size(); fileNo++) { - BufferedInputStream bis = new BufferedInputStream(fs.open(files.get(fileNo))); - H5RootObject rootObject = H5.H5Fopen(bis); + H5ByteReader byteReader = createByteReader(files.get(fileNo), fs); + H5RootObject rootObject = H5.H5Fopen(byteReader); H5.H5Dopen(rootObject, datasetName); - int[] dims = rootObject.getDimensions(); - nrow += dims[0]; - ncol += dims[1]; + nrow += (int) rootObject.getRow(); + ncol += (int) rootObject.getCol(); - IOUtilFunctions.closeSilently(bis); + rootObject.close(); } // allocate target matrix block based on given size; return createOutputMatrixBlock(nrow, ncol, nrow, estnnz, true, true); } + + private static int getHdf5ReadInt(String key, int defaultValue) { + String value = System.getProperty(key); + if(value == null) + return defaultValue; + try { + long parsed = Long.parseLong(value.trim()); + if(parsed <= 0 || parsed > Integer.MAX_VALUE) + return defaultValue; + return (int) parsed; + } + catch(NumberFormatException ex) { + return defaultValue; + } + } + + private static boolean getHdf5ReadBoolean(String key, boolean defaultValue) { + String value = System.getProperty(key); + if(value == null) + return defaultValue; + return Boolean.parseBoolean(value.trim()); + } + + static java.io.File getLocalFile(Path path) { + try { + return new java.io.File(path.toUri()); + } + catch(IllegalArgumentException ex) { + return new java.io.File(path.toString()); + } + } + + private static ByteBuffer sliceBuffer(ByteBuffer source, int offset, int length) { + ByteBuffer dup = source.duplicate(); + dup.position(offset); + dup.limit(offset + length); + return dup.slice(); + } + + static boolean isLocalFileSystem(FileSystem fs) { + if(fs instanceof LocalFileSystem || fs instanceof RawLocalFileSystem) + return true; + String scheme = fs.getScheme(); + return scheme != null && scheme.equalsIgnoreCase("file"); + } + + static H5ByteReader createByteReader(Path path, FileSystem fs) throws IOException { + long fileLength = fs.getFileStatus(path).getLen(); + String sourceId = path.toString(); + 
if(isLocalFileSystem(fs)) { + FileInputStream fis = new FileInputStream(getLocalFile(path)); + FileChannel channel = fis.getChannel(); + long length = channel.size(); + System.out.println("[HDF5] Using FileChannel-backed reader for " + sourceId + " (size=" + length + ")"); + if(HDF5_READ_USE_MMAP && length > 0) { + return new MappedH5ByteReader(channel, length, HDF5_READ_MAP_BYTES); + } + H5ByteReader base = new FileChannelByteReader(channel); + if(length > 0 && length <= Integer.MAX_VALUE) { + return new BufferedH5ByteReader(base, length, HDF5_READ_BUFFER_BYTES); + } + return base; + } + FSDataInputStream fsin = fs.open(path); + return createByteReader(fsin, sourceId, fileLength); + } + + private static final class FsDataInputStreamByteReader implements H5ByteReader { + private final FSDataInputStream input; + + FsDataInputStreamByteReader(FSDataInputStream input) { + this.input = input; + } + + @Override + public ByteBuffer read(long offset, int length) throws IOException { + byte[] buffer = new byte[length]; + input.readFully(offset, buffer, 0, length); + return ByteBuffer.wrap(buffer); + } + + @Override + public ByteBuffer read(long offset, int length, ByteBuffer reuse) throws IOException { + if(reuse == null || reuse.capacity() < length || !reuse.hasArray()) { + return read(offset, length); + } + byte[] buffer = reuse.array(); + int baseOffset = reuse.arrayOffset(); + input.readFully(offset, buffer, baseOffset, length); + reuse.position(baseOffset); + reuse.limit(baseOffset + length); + if(baseOffset == 0) { + return reuse; + } + return reuse.slice(); + } + + @Override + public void close() throws IOException { + input.close(); + } + } + + private static final class BufferedH5ByteReader implements H5ByteReader { + private final H5ByteReader base; + private final long length; + private final int windowSize; + private long windowStart = -1; + private int windowLength; + private ByteBuffer window; + private ByteBuffer windowStorage; + + BufferedH5ByteReader(H5ByteReader base, long length, int windowSize) { + this.base = base; + this.length = length; + this.windowSize = windowSize; + } + + @Override + public ByteBuffer read(long offset, int length) throws IOException { + if(length <= 0 || length > windowSize) { + return base.read(offset, length); + } + if(this.length > 0 && offset + length > this.length) { + return base.read(offset, length); + } + if(window != null && offset >= windowStart && offset + length <= windowStart + windowLength) { + return sliceBuffer(window, (int) (offset - windowStart), length); + } + int readSize = windowSize; + if(this.length > 0) { + long remaining = this.length - offset; + if(remaining > 0) + readSize = (int) Math.min(readSize, remaining); + } + if(readSize < length) { + readSize = length; + } + if(windowStorage == null || windowStorage.capacity() < readSize) { + windowStorage = ByteBuffer.allocate(windowSize); + } + window = base.read(offset, readSize, windowStorage); + windowStart = offset; + windowLength = window.remaining(); + return sliceBuffer(window, 0, length); + } + + @Override + public void close() throws IOException { + base.close(); + } + } + + private static final class FileChannelByteReader implements H5ByteReader { + private final FileChannel channel; + + FileChannelByteReader(FileChannel channel) { + this.channel = channel; + } + + @Override + public ByteBuffer read(long offset, int length) throws IOException { + ByteBuffer buffer = ByteBuffer.allocate(length); + long pos = offset; + while(buffer.hasRemaining()) { + int read = 
channel.read(buffer, pos); + if(read < 0) + throw new IOException("Unexpected EOF while reading HDF5 data at offset " + offset); + pos += read; + } + buffer.flip(); + return buffer; + } + + @Override + public ByteBuffer read(long offset, int length, ByteBuffer reuse) throws IOException { + if(reuse == null || reuse.capacity() < length) { + return read(offset, length); + } + reuse.clear(); + reuse.limit(length); + long pos = offset; + while(reuse.hasRemaining()) { + int read = channel.read(reuse, pos); + if(read < 0) + throw new IOException("Unexpected EOF while reading HDF5 data at offset " + offset); + pos += read; + } + reuse.flip(); + return reuse; + } + + @Override + public void close() throws IOException { + channel.close(); + } + } + + private static final class MappedH5ByteReader implements H5ByteReader { + private final FileChannel channel; + private final long length; + private final int windowSize; + private long windowStart = -1; + private int windowLength; + private MappedByteBuffer window; + + MappedH5ByteReader(FileChannel channel, long length, int windowSize) { + this.channel = channel; + this.length = length; + this.windowSize = windowSize; + } + + @Override + public ByteBuffer read(long offset, int length) throws IOException { + if(length <= 0) + return ByteBuffer.allocate(0); + if(this.length > 0 && offset + length > this.length) { + throw new IOException("Attempted to read past EOF at offset " + offset + " length " + length); + } + if(length > windowSize) { + MappedByteBuffer mapped = channel.map(FileChannel.MapMode.READ_ONLY, offset, length); + return mapped; + } + if(window != null && offset >= windowStart && offset + length <= windowStart + windowLength) { + return sliceBuffer(window, (int) (offset - windowStart), length); + } + int readSize = windowSize; + if(this.length > 0) { + long remaining = this.length - offset; + if(remaining > 0) + readSize = (int) Math.min(readSize, remaining); + } + if(readSize < length) { + readSize = length; + } + window = channel.map(FileChannel.MapMode.READ_ONLY, offset, readSize); + windowStart = offset; + windowLength = readSize; + return sliceBuffer(window, 0, length); + } + + @Override + public void close() throws IOException { + channel.close(); + } + } + + private static final class ByteArrayH5ByteReader implements H5ByteReader { + private final byte[] data; + + ByteArrayH5ByteReader(byte[] data) { + this.data = data; + } + + @Override + public ByteBuffer read(long offset, int length) throws IOException { + if(offset < 0 || offset + length > data.length) { + throw new IOException("Attempted to read outside cached buffer (offset=" + offset + ", len=" + length + + ", size=" + data.length + ")"); + } + if(offset > Integer.MAX_VALUE) { + throw new IOException("Offset exceeds byte array capacity: " + offset); + } + return ByteBuffer.wrap(data, (int) offset, length).slice(); + } + + @Override + public void close() { + // nothing to close + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderHDF5Parallel.java b/src/main/java/org/apache/sysds/runtime/io/ReaderHDF5Parallel.java index 658eb538265..36e2a3e1bfc 100644 --- a/src/main/java/org/apache/sysds/runtime/io/ReaderHDF5Parallel.java +++ b/src/main/java/org/apache/sysds/runtime/io/ReaderHDF5Parallel.java @@ -19,9 +19,13 @@ package org.apache.sysds.runtime.io; -import java.io.BufferedInputStream; +import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import 
java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; import java.util.ArrayList; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; @@ -32,13 +36,21 @@ import org.apache.hadoop.mapred.FileInputFormat; import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.TextInputFormat; +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.io.hdf5.H5Constants; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP; +import org.apache.sysds.runtime.data.DenseBlockLFP64DEDUP; +import org.apache.sysds.runtime.io.hdf5.H5ByteReader; +import org.apache.sysds.runtime.io.hdf5.H5ContiguousDataset; +import org.apache.sysds.runtime.io.hdf5.H5RootObject; +import org.apache.sysds.runtime.io.hdf5.H5; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.runtime.util.UtilFunctions; public class ReaderHDF5Parallel extends ReaderHDF5 { @@ -71,12 +83,19 @@ public MatrixBlock readMatrixFromHDFS(String fname, long rlen, long clen, int bl ArrayList files = new ArrayList<>(); files.add(path); MatrixBlock src = computeHDF5Size(files, fs, _props.getDatasetName(), estnnz); + if(ReaderHDF5.isLocalFileSystem(fs) && !fs.getFileStatus(path).isDirectory()) { + Long nnz = readMatrixFromHDF5ParallelLocal(path, fs, src, 0, src.getNumRows(), + src.getNumColumns(), blen, _props.getDatasetName()); + if(nnz != null) { + src.setNonZeros(nnz); + return src; + } + } int numParts = Math.min(files.size(), _numThreads); //create and execute tasks ExecutorService pool = CommonThreadPool.get(_numThreads); try { - int bufferSize = (src.getNumColumns() * src.getNumRows()) * 8 + H5Constants.STATIC_HEADER_SIZE; ArrayList tasks = new ArrayList<>(); rlen = src.getNumRows(); int blklen = (int) Math.ceil((double) rlen / numParts); @@ -85,10 +104,7 @@ public MatrixBlock readMatrixFromHDFS(String fname, long rlen, long clen, int bl int ru = (int) Math.min((i + 1) * blklen, rlen); Path newPath = HDFSTool.isDirectory(fs, path) ? 
new Path(path, IOUtilFunctions.getPartFileName(i)) : path; - BufferedInputStream bis = new BufferedInputStream(fs.open(newPath), bufferSize); - - //BufferedInputStream bis, String datasetName, MatrixBlock src, MutableInt rl, int ru - tasks.add(new ReadHDF5Task(bis, _props.getDatasetName(), src, rl, ru, clen, blklen)); + tasks.add(new ReadHDF5Task(fs, newPath, _props.getDatasetName(), src, rl, ru, clen, blklen)); } long nnz = 0; @@ -113,9 +129,208 @@ public MatrixBlock readMatrixFromInputStream(InputStream is, long rlen, long cle return new ReaderHDF5(_props).readMatrixFromInputStream(is, rlen, clen, blen, estnnz); } + private static Long readMatrixFromHDF5ParallelLocal(Path path, FileSystem fs, MatrixBlock dest, + int rl, long ru, long clen, int blen, String datasetName) throws IOException + { + H5RootObject rootObject = null; + long dataAddress; + long elemSize; + long rows; + long cols; + try { + H5ByteReader metaReader = createByteReader(path, fs); + rootObject = H5.H5Fopen(metaReader); + H5ContiguousDataset dataset = H5.H5Dopen(rootObject, datasetName); + if(dataset.isRankGt2() && !dataset.isRowContiguous()) { + rootObject.close(); + return null; + } + elemSize = dataset.getElementSize(); + if(elemSize != 8) { + rootObject.close(); + return null; + } + dataAddress = dataset.getDataAddress(); + rows = rootObject.getRow(); + cols = rootObject.getCol(); + long rowByteSize = dataset.getRowByteSize(); + if(rowByteSize <= 0) { + rootObject.close(); + return null; + } + rootObject.close(); + rootObject = null; + } + finally { + if(rootObject != null) + rootObject.close(); + } + + if(dest.isInSparseFormat()) { + if(HDF5_FORCE_DENSE) { + dest.allocateDenseBlock(true); + if(HDF5_READ_TRACE) + System.out.println("[HDF5] Forcing dense output for parallel mmap dataset=" + datasetName); + } + else { + return null; + } + } + DenseBlock denseBlock = dest.getDenseBlock(); + boolean fastDense = denseBlock.isNumeric(ValueType.FP64) + && !(denseBlock instanceof DenseBlockFP64DEDUP) + && !(denseBlock instanceof DenseBlockLFP64DEDUP); + boolean contiguousDense = fastDense && denseBlock.isContiguous(); + if(!fastDense) { + return null; + } + + if(cols > Integer.MAX_VALUE || rows > Integer.MAX_VALUE) { + return null; + } + int ncol = (int) cols; + long rowBytesLong = elemSize * ncol; + if(rowBytesLong <= 0 || rowBytesLong > Integer.MAX_VALUE) { + return null; + } + long totalRowsLong = ru - rl; + if(totalRowsLong <= 0 || totalRowsLong > Integer.MAX_VALUE) { + return null; + } + long totalBytes = totalRowsLong * rowBytesLong; + if(totalBytes < HDF5_READ_PARALLEL_MIN_BYTES || HDF5_READ_PARALLEL_THREADS <= 1) { + return null; + } + + int numThreads = Math.min(HDF5_READ_PARALLEL_THREADS, (int) totalRowsLong); + int rowsPerTask = (int) Math.ceil((double) totalRowsLong / numThreads); + double[] destAll = contiguousDense ? denseBlock.values(0) : null; + int destBase = contiguousDense ? 
denseBlock.pos(rl) : 0; + int rowBytes = (int) rowBytesLong; + int windowBytes = HDF5_READ_MAP_BYTES; + boolean skipNnz = HDF5_SKIP_NNZ; + if(HDF5_READ_TRACE) { + System.out.println("[HDF5] Parallel mmap read enabled dataset=" + datasetName + " rows=" + totalRowsLong + + " cols=" + cols + " threads=" + numThreads + " windowBytes=" + windowBytes + " skipNnz=" + skipNnz); + } + + java.io.File localFile = getLocalFile(path); + ExecutorService pool = CommonThreadPool.get(numThreads); + ArrayList> tasks = new ArrayList<>(); + for(int rowOffset = 0; rowOffset < totalRowsLong; rowOffset += rowsPerTask) { + int rowsToRead = (int) Math.min(rowsPerTask, totalRowsLong - rowOffset); + int destOffset = contiguousDense ? destBase + rowOffset * ncol : 0; + int startRow = rl + rowOffset; + long fileOffset = dataAddress + ((long) (rl + rowOffset) * rowBytes); + tasks.add(new H5ParallelReadTask(localFile, fileOffset, rowBytes, rowsToRead, ncol, destAll, + destOffset, denseBlock, startRow, windowBytes, skipNnz)); + } + + long lnnz = 0; + try { + for(Future task : pool.invokeAll(tasks)) + lnnz += task.get(); + } + catch(Exception e) { + throw new IOException("Failed parallel read of HDF5 input.", e); + } + finally { + pool.shutdown(); + } + + if(skipNnz) { + lnnz = Math.multiplyExact(totalRowsLong, clen); + } + return lnnz; + } + + private static final class H5ParallelReadTask implements Callable { + private static final int ELEM_BYTES = 8; + private final java.io.File file; + private final long fileOffset; + private final int rowBytes; + private final int rows; + private final int ncol; + private final double[] dest; + private final int destOffset; + private final DenseBlock denseBlock; + private final int startRow; + private final int windowBytes; + private final boolean skipNnz; + + H5ParallelReadTask(java.io.File file, long fileOffset, int rowBytes, int rows, int ncol, double[] dest, + int destOffset, DenseBlock denseBlock, int startRow, int windowBytes, boolean skipNnz) + { + this.file = file; + this.fileOffset = fileOffset; + this.rowBytes = rowBytes; + this.rows = rows; + this.ncol = ncol; + this.dest = dest; + this.destOffset = destOffset; + this.denseBlock = denseBlock; + this.startRow = startRow; + this.windowBytes = windowBytes; + this.skipNnz = skipNnz; + } + + @Override + public Long call() throws IOException { + long nnz = 0; + long remaining = (long) rows * rowBytes; + long offset = fileOffset; + int destIndex = destOffset; + int rowCursor = startRow; + int window = Math.max(windowBytes, ELEM_BYTES); + try(FileInputStream fis = new FileInputStream(file); + FileChannel channel = fis.getChannel()) { + while(remaining > 0) { + int mapBytes; + if(dest != null) { + mapBytes = (int) Math.min(window, remaining); + mapBytes -= mapBytes % ELEM_BYTES; + if(mapBytes == 0) + mapBytes = (int) Math.min(remaining, ELEM_BYTES); + } + else { + int rowsInMap = (int) Math.min(remaining / rowBytes, window / rowBytes); + if(rowsInMap <= 0) + rowsInMap = 1; + mapBytes = rowsInMap * rowBytes; + } + MappedByteBuffer map = channel.map(FileChannel.MapMode.READ_ONLY, offset, mapBytes); + map.order(ByteOrder.LITTLE_ENDIAN); + DoubleBuffer db = map.asDoubleBuffer(); + int doubles = mapBytes / ELEM_BYTES; + if(dest != null) { + db.get(dest, destIndex, doubles); + if(!skipNnz) + nnz += UtilFunctions.computeNnz(dest, destIndex, doubles); + destIndex += doubles; + } + else { + int rowsRead = mapBytes / rowBytes; + for(int r = 0; r < rowsRead; r++) { + double[] rowVals = denseBlock.values(rowCursor + r); + int rowPos = 
denseBlock.pos(rowCursor + r); + db.get(rowVals, rowPos, ncol); + if(!skipNnz) + nnz += UtilFunctions.computeNnz(rowVals, rowPos, ncol); + } + rowCursor += rowsRead; + } + offset += mapBytes; + remaining -= mapBytes; + } + } + return nnz; + } + } + private static class ReadHDF5Task implements Callable { - private final BufferedInputStream _bis; + private final FileSystem _fs; + private final Path _path; private final String _datasetName; private final MatrixBlock _src; private final int _rl; @@ -123,10 +338,11 @@ private static class ReadHDF5Task implements Callable { private final long _clen; private final int _blen; - public ReadHDF5Task(BufferedInputStream bis, String datasetName, MatrixBlock src, + public ReadHDF5Task(FileSystem fs, Path path, String datasetName, MatrixBlock src, int rl, int ru, long clen, int blen) { - _bis = bis; + _fs = fs; + _path = path; _datasetName = datasetName; _src = src; _rl = rl; @@ -137,7 +353,9 @@ public ReadHDF5Task(BufferedInputStream bis, String datasetName, MatrixBlock src @Override public Long call() throws IOException { - return readMatrixFromHDF5(_bis, _datasetName, _src, _rl, _ru, _clen, _blen); + try(H5ByteReader byteReader = ReaderHDF5.createByteReader(_path, _fs)) { + return readMatrixFromHDF5(byteReader, _datasetName, _src, _rl, _ru, _clen, _blen); + } } } } diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5.java index 0ab909f0a3b..0f640490ed6 100644 --- a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5.java +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5.java @@ -19,10 +19,12 @@ package org.apache.sysds.runtime.io.hdf5; -import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import org.apache.sysds.runtime.io.hdf5.message.H5SymbolTableMessage; @@ -35,16 +37,15 @@ public class H5 { // 4. Write/Read // 5. 
Close File - public static H5RootObject H5Fopen(BufferedInputStream bis) { + public static H5RootObject H5Fopen(H5ByteReader reader) { H5RootObject rootObject = new H5RootObject(); - bis.mark(0); try { // Find out if the file is a HDF5 file int maxSignatureLength = 2048; boolean validSignature = false; long offset; for(offset = 0; offset < maxSignatureLength; offset = nextOffset(offset)) { - validSignature = H5Superblock.verifySignature(bis, offset); + validSignature = H5Superblock.verifySignature(reader, offset); if(validSignature) { break; } @@ -52,9 +53,9 @@ public static H5RootObject H5Fopen(BufferedInputStream bis) { if(!validSignature) { throw new H5RuntimeException("No valid HDF5 signature found"); } - rootObject.setBufferedInputStream(bis); + rootObject.setByteReader(reader); - final H5Superblock superblock = new H5Superblock(bis, offset); + final H5Superblock superblock = new H5Superblock(reader, offset); rootObject.setSuperblock(superblock); } catch(Exception exception) { @@ -113,38 +114,79 @@ public static H5RootObject H5Screate(BufferedOutputStream bos, long row, long co // Open a Data Space public static H5ContiguousDataset H5Dopen(H5RootObject rootObject, String datasetName) { try { - H5SymbolTableEntry symbolTableEntry = new H5SymbolTableEntry(rootObject, + List datasetPath = normalizeDatasetPath(datasetName); + H5SymbolTableEntry currentEntry = new H5SymbolTableEntry(rootObject, rootObject.getSuperblock().rootGroupSymbolTableAddress - rootObject.getSuperblock().baseAddressByte); + rootObject.setDatasetName(datasetName); - H5ObjectHeader objectHeader = new H5ObjectHeader(rootObject, symbolTableEntry.getObjectHeaderAddress()); - - final H5SymbolTableMessage stm = (H5SymbolTableMessage) objectHeader.getMessages().get(0); - final H5BTree rootBTreeNode = new H5BTree(rootObject, stm.getbTreeAddress()); - final H5LocalHeap rootNameHeap = new H5LocalHeap(rootObject, stm.getLocalHeapAddress()); - final ByteBuffer nameBuffer = rootNameHeap.getDataBuffer(); - final List childAddresses = rootBTreeNode.getChildAddresses(); - - long child = childAddresses.get(0); + StringBuilder traversedPath = new StringBuilder("/"); + for(String segment : datasetPath) { + currentEntry = descendIntoChild(rootObject, currentEntry, segment, traversedPath.toString()); + if(traversedPath.length() > 1) + traversedPath.append('/'); + traversedPath.append(segment); + } - H5GroupSymbolTableNode groupSTE = new H5GroupSymbolTableNode(rootObject, child); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] Opening dataset '" + datasetName + "' resolved to object header @ " + + currentEntry.getObjectHeaderAddress()); + } - symbolTableEntry = groupSTE.getSymbolTableEntries()[0]; + final H5ObjectHeader header = new H5ObjectHeader(rootObject, currentEntry.getObjectHeaderAddress()); + return new H5ContiguousDataset(rootObject, header); - nameBuffer.position(symbolTableEntry.getLinkNameOffset()); - String childName = Utils.readUntilNull(nameBuffer); + } + catch(Exception exception) { + throw new H5RuntimeException(exception); + } + } - if(!childName.equals(datasetName)) { - throw new H5RuntimeException("The requested dataset '" + datasetName + "' differs from available '"+childName+"'."); + private static H5SymbolTableEntry descendIntoChild(H5RootObject rootObject, H5SymbolTableEntry parentEntry, + String childSegment, String currentPath) { + H5ObjectHeader objectHeader = new H5ObjectHeader(rootObject, parentEntry.getObjectHeaderAddress()); + H5SymbolTableMessage symbolTableMessage = 
objectHeader.getMessageOfType(H5SymbolTableMessage.class); + List children = readSymbolTableEntries(rootObject, symbolTableMessage); + H5LocalHeap heap = new H5LocalHeap(rootObject, symbolTableMessage.getLocalHeapAddress()); + ByteBuffer nameBuffer = heap.getDataBuffer(); + List availableNames = new ArrayList<>(); + for(H5SymbolTableEntry child : children) { + nameBuffer.position(child.getLinkNameOffset()); + String candidateName = Utils.readUntilNull(nameBuffer); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] Visit '" + currentPath + "' child -> '" + candidateName + "'"); + } + availableNames.add(candidateName); + if(candidateName.equals(childSegment)) { + return child; } + } + throw new H5RuntimeException("Dataset path segment '" + childSegment + "' not found under '" + currentPath + + "'. Available entries: " + availableNames); + } - final H5ObjectHeader header = new H5ObjectHeader(rootObject, symbolTableEntry.getObjectHeaderAddress()); - final H5ContiguousDataset contiguousDataset = new H5ContiguousDataset(rootObject, header); - return contiguousDataset; + private static List readSymbolTableEntries(H5RootObject rootObject, + H5SymbolTableMessage symbolTableMessage) { + H5BTree btree = new H5BTree(rootObject, symbolTableMessage.getbTreeAddress()); + List entries = new ArrayList<>(); + for(Long childAddress : btree.getChildAddresses()) { + H5GroupSymbolTableNode groupNode = new H5GroupSymbolTableNode(rootObject, childAddress); + entries.addAll(Arrays.asList(groupNode.getSymbolTableEntries())); + } + return entries; + } + private static List normalizeDatasetPath(String datasetName) { + if(datasetName == null) { + throw new H5RuntimeException("Dataset name cannot be null"); } - catch(Exception exception) { - throw new H5RuntimeException(exception); + List tokens = Arrays.stream(datasetName.split("/")) + .map(String::trim) + .filter(token -> !token.isEmpty()) + .collect(Collectors.toList()); + if(tokens.isEmpty()) { + throw new H5RuntimeException("Dataset name '" + datasetName + "' is invalid."); } + return tokens; } // Create Dataset @@ -196,14 +238,12 @@ public static void H5Dwrite(H5RootObject rootObject, double[][] data) { public static void H5Dread(H5RootObject rootObject, H5ContiguousDataset dataset, double[][] data) { for(int i = 0; i < rootObject.getRow(); i++) { - ByteBuffer buffer = dataset.getDataBuffer(i); - dataset.getDataType().getDoubleDataType().fillData(buffer, data[i]); + dataset.readRowDoubles(i, data[i], 0); } } public static void H5Dread(H5ContiguousDataset dataset, int row, double[] data) { - ByteBuffer buffer = dataset.getDataBuffer(row); - dataset.getDataType().getDoubleDataType().fillData(buffer, data); + dataset.readRowDoubles(row, data, 0); } } diff --git a/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test1.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5ByteReader.java similarity index 66% rename from src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test1.java rename to src/main/java/org/apache/sysds/runtime/io/hdf5/H5ByteReader.java index b0fff7a6391..5421e5f3b0f 100644 --- a/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test1.java +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5ByteReader.java @@ -17,22 +17,20 @@ * under the License. 
 */
-package org.apache.sysds.test.functions.io.hdf5;
+package org.apache.sysds.runtime.io.hdf5;
-public class ReadHDF5Test1 extends ReadHDF5Test {
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
-	private final static String TEST_NAME = "ReadHDF5Test";
-	public final static String TEST_CLASS_DIR = TEST_DIR + ReadHDF5Test1.class.getSimpleName() + "/";
+public interface H5ByteReader extends Closeable {
-	protected String getTestName() {
-		return TEST_NAME;
-	}
+	ByteBuffer read(long offset, int length) throws IOException;
-	protected String getTestClassDir() {
-		return TEST_CLASS_DIR;
+	default ByteBuffer read(long offset, int length, ByteBuffer reuse) throws IOException {
+		return read(offset, length);
 	}
-	protected int getId() {
-		return 1;
-	}
+	@Override
+	void close() throws IOException;
 }
diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5Constants.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5Constants.java
index 9d2414bec84..f80690454d8 100644
--- a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5Constants.java
+++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5Constants.java
@@ -30,4 +30,6 @@ public final class H5Constants {
 	public static final int DATA_LAYOUT_MESSAGE = 8;
 	public static final int SYMBOL_TABLE_MESSAGE = 17;
 	public static final int OBJECT_MODIFICATION_TIME_MESSAGE = 18;
+	public static final int FILTER_PIPELINE_MESSAGE = 11;
+	public static final int ATTRIBUTE_MESSAGE = 12;
 }
diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5ContiguousDataset.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5ContiguousDataset.java
index 3ae6761e864..b132ea6a5aa 100644
--- a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5ContiguousDataset.java
+++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5ContiguousDataset.java
@@ -25,6 +25,7 @@ import org.apache.sysds.runtime.io.hdf5.message.H5DataTypeMessage;
 import java.nio.ByteBuffer;
+import java.util.Arrays;
 import static java.nio.ByteOrder.LITTLE_ENDIAN;
@@ -35,29 +36,235 @@ public class H5ContiguousDataset {
 	private final H5DataTypeMessage dataTypeMessage;
 	@SuppressWarnings("unused")
 	private final H5DataSpaceMessage dataSpaceMessage;
+	private final boolean rankGt2;
+	private final long elemSize;
+	private final long dataSize;
+	private ByteBuffer fullData;
+	private boolean fullDataLoaded = false;
+	private final int[] dims;
+	private final int[] fileDims;
+	private final long[] fileStrides;
+	private final int[] axisPermutation;
+	private final long rowByteStride;
+	private final long rowByteSize;
+	private long[] colOffsets;
 	public H5ContiguousDataset(H5RootObject rootObject, H5ObjectHeader objectHeader) {
 		this.rootObject = rootObject;
 		this.dataLayoutMessage = objectHeader.getMessageOfType(H5DataLayoutMessage.class);
+		if(this.dataLayoutMessage.getLayoutClass() != H5DataLayoutMessage.LAYOUT_CLASS_CONTIGUOUS) {
+			throw new H5RuntimeException("Unsupported data layout class: " +
+				this.dataLayoutMessage.getLayoutClass() + " (only contiguous datasets are supported).");
+		}
 		this.dataTypeMessage = objectHeader.getMessageOfType(H5DataTypeMessage.class);
 		this.dataSpaceMessage = objectHeader.getMessageOfType(H5DataSpaceMessage.class);
+
+		this.dims = rootObject.getLogicalDimensions();
+		this.fileDims = rootObject.getRawDimensions() != null ? 
rootObject.getRawDimensions() : this.dims; + this.axisPermutation = normalizePermutation(rootObject.getAxisPermutation(), this.dims); + this.rankGt2 = this.dims != null && this.dims.length > 2; + this.elemSize = this.dataTypeMessage.getDoubleDataType().getSize(); + this.dataSize = this.dataLayoutMessage.getSize(); + this.fileStrides = computeStridesRowMajor(this.fileDims); + this.rowByteStride = (fileStrides.length == 0) ? 0 : fileStrides[axisPermutation[0]] * elemSize; + if(H5RootObject.HDF5_DEBUG && rankGt2) { + System.out.println("[HDF5] dataset=" + rootObject.getDatasetName() + " logicalDims=" + + Arrays.toString(dims) + " fileDims=" + Arrays.toString(fileDims) + " axisPerm=" + + Arrays.toString(axisPermutation) + " fileStrides=" + Arrays.toString(fileStrides)); + } + + this.rowByteSize = rootObject.getCol() * elemSize; } public ByteBuffer getDataBuffer(int row) { + return getDataBuffer(row, 1); + } + + public ByteBuffer getDataBuffer(int row, int rowCount) { try { - long rowPos = row * rootObject.getCol()*this.dataTypeMessage.getDoubleDataType().getSize(); - ByteBuffer data = rootObject.readBufferFromAddressNoOrder(dataLayoutMessage.getAddress() + rowPos, - (int) (rootObject.getCol() * this.dataTypeMessage.getDoubleDataType().getSize())); - data.order(LITTLE_ENDIAN); + long cols = rootObject.getCol(); + long rowBytes = cols * elemSize; + if(rowBytes > Integer.MAX_VALUE) { + throw new H5RuntimeException("Row byte size exceeds buffer capacity: " + rowBytes); + } + if(rowCount <= 0) { + throw new H5RuntimeException("Row count must be positive, got " + rowCount); + } + long readLengthLong = rowBytes * rowCount; + if(readLengthLong > Integer.MAX_VALUE) { + throw new H5RuntimeException("Requested read exceeds buffer capacity: " + readLengthLong); + } + int readLength = (int) readLengthLong; - return data; + if(rankGt2) { + if(isRowContiguous()) { + long rowPos = row * rowByteSize; + long layoutAddress = dataLayoutMessage.getAddress(); + long dataAddress = layoutAddress + rowPos; + ByteBuffer data = rootObject.readBufferFromAddressNoOrder(dataAddress, readLength); + data.order(LITTLE_ENDIAN); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] getDataBuffer (rank>2 contiguous) dataset=" + rootObject.getDatasetName() + + " row=" + row + " rowCount=" + rowCount + " readLength=" + readLength); + } + return data; + } + if(rowCount != 1) { + throw new H5RuntimeException("Row block reads are not supported for non-contiguous rank>2 datasets."); + } + if(!fullDataLoaded) { + fullData = rootObject.readBufferFromAddressNoOrder(dataLayoutMessage.getAddress(), + (int) dataSize); + fullData.order(LITTLE_ENDIAN); + fullDataLoaded = true; + } + if(colOffsets == null) { + colOffsets = new long[(int) cols]; + for(int c = 0; c < cols; c++) { + colOffsets[c] = computeByteOffset(0, c); + } + } + ByteBuffer rowBuf = ByteBuffer.allocate(readLength).order(LITTLE_ENDIAN); + if(H5RootObject.HDF5_DEBUG && row == 0) { + long debugCols = Math.min(cols, 5); + for(long c = 0; c < debugCols; c++) { + long byteOff = rowByteStride * row + colOffsets[(int) c]; + double v = fullData.getDouble((int) byteOff); + System.out.println("[HDF5] map(row=" + row + ", col=" + c + ") -> byteOff=" + byteOff + + " val=" + v); + } + } + for(int c = 0; c < cols; c++) { + long byteOff = rowByteStride * row + colOffsets[c]; + double v = fullData.getDouble((int) byteOff); + if(H5RootObject.HDF5_DEBUG && row == 3 && c == 3) { + System.out.println("[HDF5] sample(row=" + row + ", col=" + c + ") byteOff=" + byteOff + + " val=" + v); + } + 
rowBuf.putDouble(v); + } + rowBuf.rewind(); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] getDataBuffer (rank>2) dataset=" + rootObject.getDatasetName() + " row=" + row + + " cols=" + cols + " elemSize=" + elemSize + " dataSize=" + dataSize); + } + return rowBuf; + } + else { + long rowPos = row * rowBytes; + long layoutAddress = dataLayoutMessage.getAddress(); + // layoutAddress is already an absolute file offset for the contiguous data block. + long dataAddress = layoutAddress + rowPos; + ByteBuffer data = rootObject.readBufferFromAddressNoOrder(dataAddress, readLength); + data.order(LITTLE_ENDIAN); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] getDataBuffer dataset=" + rootObject.getDatasetName() + " row=" + row + + " layoutAddr=" + layoutAddress + " rowPos=" + rowPos + " readLength=" + readLength + + " col=" + cols + " rowCount=" + rowCount); + } + return data; + } } catch(Exception e) { throw new H5RuntimeException("Failed to map data buffer for dataset", e); } } + + public void readRowDoubles(int row, double[] dest, int destPos) { + long cols = rootObject.getCol(); + if(cols > Integer.MAX_VALUE) { + throw new H5RuntimeException("Column count exceeds buffer capacity: " + cols); + } + int ncol = (int) cols; + if(rankGt2) { + if(isRowContiguous()) { + ByteBuffer data = getDataBuffer(row, 1); + data.order(LITTLE_ENDIAN); + data.asDoubleBuffer().get(dest, destPos, ncol); + return; + } + if(!fullDataLoaded) { + fullData = rootObject.readBufferFromAddressNoOrder(dataLayoutMessage.getAddress(), (int) dataSize); + fullData.order(LITTLE_ENDIAN); + fullDataLoaded = true; + } + if(colOffsets == null) { + colOffsets = new long[ncol]; + for(int c = 0; c < ncol; c++) { + colOffsets[c] = computeByteOffset(0, c); + } + } + long rowBase = rowByteStride * row; + for(int c = 0; c < ncol; c++) { + dest[destPos + c] = fullData.getDouble((int) (rowBase + colOffsets[c])); + } + return; + } + ByteBuffer data = getDataBuffer(row); + data.order(LITTLE_ENDIAN); + data.asDoubleBuffer().get(dest, destPos, ncol); + } + + private static long[] computeStridesRowMajor(int[] dims) { + if(dims == null || dims.length == 0) + return new long[0]; + long[] strides = new long[dims.length]; + strides[dims.length - 1] = 1; + for(int i = dims.length - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * dims[i + 1]; + } + return strides; + } + + private long computeByteOffset(long row, long col) { + long linear = row * fileStrides[axisPermutation[0]]; + long rem = col; + for(int axis = dims.length - 1; axis >= 1; axis--) { + int dim = dims[axis]; + long idx = (dim == 0) ? 0 : rem % dim; + rem = (dim == 0) ? 0 : rem / dim; + linear += idx * fileStrides[axisPermutation[axis]]; + } + return linear * elemSize; + } + + private static int[] normalizePermutation(int[] permutation, int[] dims) { + int rank = (dims == null) ? 
0 : dims.length; + if(permutation == null || permutation.length != rank) { + int[] identity = new int[rank]; + for(int i = 0; i < rank; i++) + identity[i] = i; + return identity; + } + return permutation; + } + public H5DataTypeMessage getDataType() { return dataTypeMessage; } + + public long getDataAddress() { + return dataLayoutMessage.getAddress(); + } + + public long getDataSize() { + return dataSize; + } + + public long getElementSize() { + return elemSize; + } + + public boolean isRankGt2() { + return rankGt2; + } + + public long getRowByteSize() { + return rowByteSize; + } + + public boolean isRowContiguous() { + return rowByteStride == rowByteSize; + } } diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5RootObject.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5RootObject.java index ebfb719e0be..823359660fb 100644 --- a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5RootObject.java +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5RootObject.java @@ -19,22 +19,24 @@ package org.apache.sysds.runtime.io.hdf5; -import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; import static java.nio.ByteOrder.LITTLE_ENDIAN; public class H5RootObject { - protected BufferedInputStream bufferedInputStream; + protected H5ByteReader byteReader; protected BufferedOutputStream bufferedOutputStream; protected H5Superblock superblock; protected int rank; protected long row; protected long col; - protected int[] dimensions; + protected int[] logicalDimensions; + protected int[] rawDimensions; + protected int[] axisPermutation; protected long maxRow; protected long maxCol; protected int[] maxSizes; @@ -50,46 +52,47 @@ public class H5RootObject { protected byte groupSymbolTableNodeVersion = 1; protected byte dataLayoutClass = 1; + public static final boolean HDF5_DEBUG = Boolean.getBoolean("sysds.hdf5.debug"); public ByteBuffer readBufferFromAddress(long address, int length) { - ByteBuffer bb = ByteBuffer.allocate(length); try { - byte[] b = new byte[length]; - bufferedInputStream.reset(); - bufferedInputStream.skip(address); - bufferedInputStream.read(b); - bb.put(b); + ByteBuffer bb = byteReader.read(address, length); + bb.order(LITTLE_ENDIAN); + bb.rewind(); + return bb; } catch(IOException e) { throw new H5RuntimeException(e); } - bb.order(LITTLE_ENDIAN); - bb.rewind(); - return bb; } public ByteBuffer readBufferFromAddressNoOrder(long address, int length) { - ByteBuffer bb = ByteBuffer.allocate(length); try { - byte[] b = new byte[length]; - bufferedInputStream.reset(); - bufferedInputStream.skip(address); - bufferedInputStream.read(b); - bb.put(b); + ByteBuffer bb = byteReader.read(address, length); + bb.rewind(); + return bb; } catch(IOException e) { throw new H5RuntimeException(e); } - bb.rewind(); - return bb; } - public BufferedInputStream getBufferedInputStream() { - return bufferedInputStream; + public void setByteReader(H5ByteReader byteReader) { + this.byteReader = byteReader; } - public void setBufferedInputStream(BufferedInputStream bufferedInputStream) { - this.bufferedInputStream = bufferedInputStream; + public H5ByteReader getByteReader() { + return byteReader; + } + + public void close() { + try { + if(byteReader != null) + byteReader.close(); + } + catch(IOException e) { + throw new H5RuntimeException(e); + } } public BufferedOutputStream getBufferedOutputStream() { @@ -114,7 +117,8 @@ public long getRow() { public void setRow(long row) { this.row = row; - 
this.dimensions[0] = (int) row; + if(this.logicalDimensions != null && this.logicalDimensions.length > 0) + this.logicalDimensions[0] = (int) row; } public long getCol() { @@ -123,7 +127,8 @@ public long getCol() { public void setCol(long col) { this.col = col; - this.dimensions[1] = (int) col; + if(this.logicalDimensions != null && this.logicalDimensions.length > 1) + this.logicalDimensions[1] = (int) col; } public int getRank() { @@ -132,7 +137,7 @@ public int getRank() { public void setRank(int rank) { this.rank = rank; - this.dimensions = new int[rank]; + this.logicalDimensions = new int[rank]; this.maxSizes = new int[rank]; } @@ -142,7 +147,8 @@ public long getMaxRow() { public void setMaxRow(long maxRow) { this.maxRow = maxRow; - this.maxSizes[0] = (int) maxRow; + if(this.maxSizes != null && this.maxSizes.length > 0) + this.maxSizes[0] = (int) maxRow; } public long getMaxCol() { @@ -151,7 +157,8 @@ public long getMaxCol() { public void setMaxCol(long maxCol) { this.maxCol = maxCol; - this.maxSizes[1] = (int) maxCol; + if(this.maxSizes != null && this.maxSizes.length > 1) + this.maxSizes[1] = (int) maxCol; } public String getDatasetName() { @@ -163,13 +170,25 @@ public void setDatasetName(String datasetName) { } public int[] getDimensions() { - return dimensions; + return logicalDimensions; + } + + public int[] getLogicalDimensions() { + return logicalDimensions; } public int[] getMaxSizes() { return maxSizes; } + public int[] getRawDimensions() { + return rawDimensions; + } + + public int[] getAxisPermutation() { + return axisPermutation; + } + public byte getDataSpaceVersion() { return dataSpaceVersion; } @@ -179,15 +198,45 @@ public void setDataSpaceVersion(byte dataSpaceVersion) { } public void setDimensions(int[] dimensions) { - this.dimensions = dimensions; - this.row = dimensions[0]; - this.col = dimensions[1]; + this.rawDimensions = dimensions; + if(dimensions == null || dimensions.length == 0) { + this.logicalDimensions = dimensions; + this.axisPermutation = new int[0]; + this.row = 0; + this.col = 0; + return; + } + int[] logical = Arrays.copyOf(dimensions, dimensions.length); + int[] permutation = identityPermutation(dimensions.length); + this.logicalDimensions = logical; + this.axisPermutation = permutation; + this.row = logicalDimensions[0]; + this.col = flattenColumns(logicalDimensions); + if(HDF5_DEBUG) { + System.out.println("[HDF5] setDimensions rank=" + dimensions.length + " rawDims=" + + java.util.Arrays.toString(dimensions) + " logicalDims=" + java.util.Arrays.toString(logicalDimensions) + + " axisPerm=" + java.util.Arrays.toString(axisPermutation) + " => rows=" + row + " cols(flat)=" + col); + } + if(HDF5_DEBUG) { + System.out.println("[HDF5] setDimensions debug raw=" + java.util.Arrays.toString(dimensions) + + " logical=" + java.util.Arrays.toString(logicalDimensions) + " perm=" + + java.util.Arrays.toString(axisPermutation)); + } } public void setMaxSizes(int[] maxSizes) { this.maxSizes = maxSizes; + if(maxSizes == null || maxSizes.length == 0) { + this.maxRow = 0; + this.maxCol = 0; + return; + } this.maxRow = maxSizes[0]; - this.maxCol = maxSizes[1]; + this.maxCol = flattenColumns(maxSizes); + if(HDF5_DEBUG) { + System.out.println("[HDF5] setMaxSizes rank=" + maxSizes.length + " max=" + java.util.Arrays.toString(maxSizes) + + " => maxRows=" + maxRow + " maxCols(flat)=" + maxCol); + } } public byte getObjectHeaderVersion() { @@ -245,4 +294,23 @@ public byte getGroupSymbolTableNodeVersion() { public void setGroupSymbolTableNodeVersion(byte 
groupSymbolTableNodeVersion) { this.groupSymbolTableNodeVersion = groupSymbolTableNodeVersion; } + + private long flattenColumns(int[] dims) { + if(dims.length == 1) { + return 1; + } + long product = 1; + for(int i = 1; i < dims.length; i++) { + product = Math.multiplyExact(product, dims[i]); + } + return product; + } + + private static int[] identityPermutation(int rank) { + int[] perm = new int[rank]; + for(int i = 0; i < rank; i++) + perm[i] = i; + return perm; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5Superblock.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5Superblock.java index 78fa90edd63..e0c921703c4 100644 --- a/src/main/java/org/apache/sysds/runtime/io/hdf5/H5Superblock.java +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/H5Superblock.java @@ -20,8 +20,6 @@ package org.apache.sysds.runtime.io.hdf5; -import java.io.BufferedInputStream; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; @@ -48,47 +46,30 @@ public class H5Superblock { public H5Superblock() { } - static boolean verifySignature(BufferedInputStream bis, long offset) { - // Format Signature - byte[] signature = new byte[HDF5_FILE_SIGNATURE_LENGTH]; - + static boolean verifySignature(H5ByteReader reader, long offset) { try { - bis.reset(); - bis.skip(offset); - bis.read(signature); + ByteBuffer signature = reader.read(offset, HDF5_FILE_SIGNATURE_LENGTH); + byte[] sigBytes = new byte[HDF5_FILE_SIGNATURE_LENGTH]; + signature.get(sigBytes); + return Arrays.equals(HDF5_FILE_SIGNATURE, sigBytes); } - catch(IOException e) { + catch(Exception e) { throw new H5RuntimeException("Failed to read from address: " + offset, e); } - // Verify signature - return Arrays.equals(HDF5_FILE_SIGNATURE, signature); } - public H5Superblock(BufferedInputStream bis, long address) { + public H5Superblock(H5ByteReader reader, long address) { // Calculated bytes for the super block header is = 56 int superBlockHeaderSize = 12; - long fileLocation = address + HDF5_FILE_SIGNATURE_LENGTH; - address += 12 + HDF5_FILE_SIGNATURE_LENGTH; - - ByteBuffer header = ByteBuffer.allocate(superBlockHeaderSize); - - try { - byte[] b = new byte[superBlockHeaderSize]; - bis.reset(); - bis.skip((int) fileLocation); - bis.read(b); - header.put(b); - } - catch(IOException e) { - throw new H5RuntimeException(e); - } - - header.order(LITTLE_ENDIAN); - header.rewind(); + long cursor = address + HDF5_FILE_SIGNATURE_LENGTH; try { + ByteBuffer header = reader.read(cursor, superBlockHeaderSize); + header.order(LITTLE_ENDIAN); + header.rewind(); + cursor += superBlockHeaderSize; // Version # of Superblock versionOfSuperblock = header.get(); @@ -125,19 +106,13 @@ public H5Superblock(BufferedInputStream bis, long address) { groupInternalNodeK = Short.toUnsignedInt(header.getShort()); // File Consistency Flags (skip) - address += 4; + cursor += 4; int nextSectionSize = 4 * sizeOfOffsets; - header = ByteBuffer.allocate(nextSectionSize); - - byte[] hb = new byte[nextSectionSize]; - bis.reset(); - bis.skip(address); - bis.read(hb); - header.put(hb); - address += nextSectionSize; + header = reader.read(cursor, nextSectionSize); header.order(LITTLE_ENDIAN); header.rewind(); + cursor += nextSectionSize; // Base Address baseAddressByte = Utils.readBytesAsUnsignedLong(header, sizeOfOffsets); @@ -152,7 +127,7 @@ public H5Superblock(BufferedInputStream bis, long address) { driverInformationBlockAddress = Utils.readBytesAsUnsignedLong(header, sizeOfOffsets); // Root Group Symbol Table Entry Address - 
rootGroupSymbolTableAddress = address; + rootGroupSymbolTableAddress = cursor; } catch(Exception e) { throw new H5RuntimeException("Failed to read superblock from address " + address, e); diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5AttributeMessage.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5AttributeMessage.java new file mode 100644 index 00000000000..9e778e8fded --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5AttributeMessage.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.io.hdf5.message; + +import java.nio.ByteBuffer; +import java.util.BitSet; + +import org.apache.sysds.runtime.io.hdf5.H5RootObject; + +/** + * Lightweight placeholder for attribute messages. We currently ignore attribute content but keep track of the + * bytes to ensure the buffer position stays consistent, logging that the attribute was skipped to aid debugging. + */ +public class H5AttributeMessage extends H5Message { + + public H5AttributeMessage(H5RootObject rootObject, BitSet flags, ByteBuffer bb) { + super(rootObject, flags); + if(bb.remaining() == 0) + return; + byte version = bb.get(); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] Skipping attribute message v" + version + " (" + bb.remaining() + " bytes payload)"); + } + // consume the rest of the payload + bb.position(bb.limit()); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataLayoutMessage.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataLayoutMessage.java index 46c49c926c6..de364cb0b09 100644 --- a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataLayoutMessage.java +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataLayoutMessage.java @@ -30,21 +30,36 @@ public class H5DataLayoutMessage extends H5Message { + public static final byte LAYOUT_CLASS_COMPACT = 0; + public static final byte LAYOUT_CLASS_CONTIGUOUS = 1; + public static final byte LAYOUT_CLASS_CHUNKED = 2; + public static final byte LAYOUT_CLASS_VIRTUAL = 3; + private final long address; private final long size; + private final byte layoutClass; + private final byte layoutVersion; public H5DataLayoutMessage(H5RootObject rootObject, BitSet flags, ByteBuffer bb) { super(rootObject, flags); rootObject.setDataLayoutVersion(bb.get()); + layoutVersion = rootObject.getDataLayoutVersion(); rootObject.setDataLayoutClass(bb.get()); + layoutClass = rootObject.getDataLayoutClass(); this.address = Utils.readBytesAsUnsignedLong(bb, rootObject.getSuperblock().sizeOfOffsets); this.size = Utils.readBytesAsUnsignedLong(bb, rootObject.getSuperblock().sizeOfLengths); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] Data layout 
(version=" + layoutVersion + ", class=" + layoutClass + ") address=" + + address + " size=" + size); + } } public H5DataLayoutMessage(H5RootObject rootObject, BitSet flags, long address, long size) { super(rootObject, flags); this.address = address; this.size = size; + this.layoutVersion = rootObject.getDataLayoutVersion(); + this.layoutClass = rootObject.getDataLayoutClass(); } @Override @@ -74,5 +89,12 @@ public long getAddress() { public long getSize() { return size; } + + public byte getLayoutClass() { + return layoutClass; + } + public byte getLayoutVersion() { + return layoutVersion; + } } diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataSpaceMessage.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataSpaceMessage.java index 68fa15f8e74..db6aae8444e 100644 --- a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataSpaceMessage.java +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataSpaceMessage.java @@ -25,6 +25,7 @@ import org.apache.sysds.runtime.io.hdf5.Utils; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.BitSet; import java.util.stream.IntStream; @@ -74,7 +75,14 @@ public H5DataSpaceMessage(H5RootObject rootObject, BitSet flags, ByteBuffer bb) } // Calculate the total length by multiplying all dimensions - totalLength = IntStream.of(rootObject.getDimensions()).mapToLong(Long::valueOf).reduce(1, Math::multiplyExact); + totalLength = IntStream.of(rootObject.getLogicalDimensions()).mapToLong(Long::valueOf) + .reduce(1, Math::multiplyExact); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] Dataspace rank=" + rootObject.getRank() + " dims=" + + Arrays.toString(rootObject.getLogicalDimensions()) + " => rows=" + rootObject.getRow() + + ", cols(flat)=" + + rootObject.getCol()); + } } @@ -97,7 +105,7 @@ public void toBuffer(H5BufferBuilder bb) { // Dimensions sizes if(rootObject.getRank() != 0) { for(int i = 0; i < rootObject.getRank(); i++) { - bb.write(rootObject.getDimensions()[i], rootObject.getSuperblock().sizeOfLengths); + bb.write(rootObject.getLogicalDimensions()[i], rootObject.getSuperblock().sizeOfLengths); } } // Max dimension sizes diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataTypeMessage.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataTypeMessage.java index cd004a11edc..ca08254175f 100644 --- a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataTypeMessage.java +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5DataTypeMessage.java @@ -35,6 +35,10 @@ public class H5DataTypeMessage extends H5Message { public H5DataTypeMessage(H5RootObject rootObject, BitSet flags, ByteBuffer bb) { super(rootObject, flags); doubleDataType = new H5DoubleDataType(bb); + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] Datatype parsed (class=" + doubleDataType.getDataClass() + ", size=" + + doubleDataType.getSize() + ")"); + } } public H5DataTypeMessage(H5RootObject rootObject, BitSet flags, H5DoubleDataType doubleDataType) { diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5FilterPipelineMessage.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5FilterPipelineMessage.java new file mode 100644 index 00000000000..f812005a7f8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5FilterPipelineMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.io.hdf5.message; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.Collections; +import java.util.List; + +import org.apache.sysds.runtime.io.hdf5.H5RootObject; +import org.apache.sysds.runtime.io.hdf5.H5RuntimeException; +import org.apache.sysds.runtime.io.hdf5.Utils; + +/** + * Minimal parser for filter pipeline messages. We currently do not support any filters, and therefore + * fail fast if we encounter one so the user understands why the dataset cannot be read. + */ +public class H5FilterPipelineMessage extends H5Message { + + private final List filterIds = new ArrayList<>(); + + public H5FilterPipelineMessage(H5RootObject rootObject, BitSet flags, ByteBuffer bb) { + super(rootObject, flags); + byte version = bb.get(); + byte numberOfFilters = bb.get(); + // Skip 6 reserved bytes + bb.position(bb.position() + 6); + + for(int i = 0; i < Byte.toUnsignedInt(numberOfFilters); i++) { + int filterId = Utils.readBytesAsUnsignedInt(bb, 2); + int nameLength = Utils.readBytesAsUnsignedInt(bb, 2); + Utils.readBytesAsUnsignedInt(bb, 2); // flags + int clientDataLength = Utils.readBytesAsUnsignedInt(bb, 2); + + if(nameLength > 0) { + byte[] nameBytes = new byte[nameLength]; + bb.get(nameBytes); + } + for(int j = 0; j < clientDataLength; j++) { + Utils.readBytesAsUnsignedInt(bb, 4); + } + Utils.seekBufferToNextMultipleOfEight(bb); + filterIds.add(filterId); + } + + if(!filterIds.isEmpty()) { + if(H5RootObject.HDF5_DEBUG) { + System.out.println("[HDF5] Detected unsupported filter pipeline v" + version + " -> " + filterIds); + } + throw new H5RuntimeException("Encountered unsupported filtered dataset (filters=" + filterIds + "). 
" + + "Compressed HDF5 inputs are currently unsupported."); + } + } + + public List getFilterIds() { + return Collections.unmodifiableList(filterIds); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5Message.java b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5Message.java index 70bb0ebeb31..cb084b85af7 100644 --- a/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5Message.java +++ b/src/main/java/org/apache/sysds/runtime/io/hdf5/message/H5Message.java @@ -142,6 +142,12 @@ private static H5Message readMessage(H5RootObject rootObject, ByteBuffer bb, int case H5Constants.OBJECT_MODIFICATION_TIME_MESSAGE: return new H5ObjectModificationTimeMessage(rootObject, flags, bb); + case H5Constants.FILTER_PIPELINE_MESSAGE: + return new H5FilterPipelineMessage(rootObject, flags, bb); + + case H5Constants.ATTRIBUTE_MESSAGE: + return new H5AttributeMessage(rootObject, flags, bb); + default: throw new H5RuntimeException("Unrecognized message type = " + messageType); } diff --git a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java index 024f5c19d08..4b6a60227fc 100644 --- a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java +++ b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java @@ -30,6 +30,7 @@ import org.apache.hadoop.fs.Path; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.parser.BooleanIdentifier; import org.apache.sysds.parser.ConstIdentifier; import org.apache.sysds.parser.DataExpression; @@ -91,45 +92,43 @@ public MetaDataAll(String mtdFileName, boolean conditional, boolean parseMeta) { parseMetaDataParams(); } - public JSONObject readMetadataFile(String filename, boolean conditional) - { + public JSONObject readMetadataFile(String filename, boolean conditional) { JSONObject retVal = new JSONObject(); boolean exists = HDFSTool.existsFileOnHDFS(filename); boolean isDir = exists ? 
HDFSTool.isDirectory(filename) : false; // CASE: filename is a directory -- process as a directory - if( exists && isDir ) - { + if(exists && isDir) { for(FileStatus stat : HDFSTool.getDirectoryListing(filename)) { Path childPath = stat.getPath(); // gives directory name - if( !childPath.getName().startsWith("part") ) + if(!childPath.getName().startsWith("part")) continue; - try (BufferedReader br = new BufferedReader(new InputStreamReader( - IOUtilFunctions.getFileSystem(childPath).open(childPath)))) - { + try(BufferedReader br = new BufferedReader( + new InputStreamReader(IOUtilFunctions.getFileSystem(childPath).open(childPath)))) { JSONObject childObj = JSONHelper.parse(br); - for( Object obj : childObj.entrySet() ){ - @SuppressWarnings("unchecked") Map.Entry e = (Map.Entry) obj; + for(Object obj : childObj.entrySet()) { + @SuppressWarnings("unchecked") + Map.Entry e = (Map.Entry) obj; Object key = e.getKey(); Object val = e.getValue(); retVal.put(key, val); } } - catch( IOException e){ - raiseValidateError("for MTD file in directory, error parting part of MTD file with path " + childPath.toString() + ": " + e.getMessage(), conditional); + catch(IOException e) { + raiseValidateError("for MTD file in directory, error parsing part of MTD file with path " + + childPath.toString() + ": " + e.getMessage(), conditional); } } } // CASE: filename points to a file - else if (exists) { + else if(exists) { Path path = new Path(filename); - try (BufferedReader br = new BufferedReader(new InputStreamReader( - IOUtilFunctions.getFileSystem(path).open(path)))) - { + try(BufferedReader br = new BufferedReader( + new InputStreamReader(IOUtilFunctions.getFileSystem(path).open(path)))) { retVal = new JSONObject(br); } - catch (Exception e){ + catch(Exception e) { raiseValidateError("error parsing MTD file with path " + filename + ": " + e.getMessage(), conditional); } } @@ -138,16 +137,15 @@ else if (exists) { } @SuppressWarnings("unchecked") - private void parseMetaDataParams() - { - for( Object obj : _metaObj.entrySet() ){ - Map.Entry e = (Map.Entry) obj; + private void parseMetaDataParams() { + for(Object obj : _metaObj.entrySet()) { + Map.Entry e = (Map.Entry) obj; Object key = e.getKey(); Object val = e.getValue(); boolean isValidName = DataExpression.READ_VALID_MTD_PARAM_NAMES.contains(key); - if (!isValidName){ //wrong parameters always rejected + if(!isValidName) { // wrong parameters always rejected raiseValidateError("MTD file contains invalid parameter name: " + key, false); } @@ -157,21 +155,40 @@ private void parseMetaDataParams() setFormatTypeString(null); } - private void parseMetaDataParam(Object key, Object val) - { + private void parseMetaDataParam(Object key, Object val) { switch(key.toString()) { - case DataExpression.READROWPARAM: _dim1 = val instanceof Long ? (Long) val : (Integer) val; break; - case DataExpression.READCOLPARAM: _dim2 = val instanceof Long ? (Long) val : (Integer) val; break; - case DataExpression.ROWBLOCKCOUNTPARAM: setBlocksize((Integer) val); break; - case DataExpression.READNNZPARAM: setNnz(val instanceof Long ? 
(Long) val : (Integer) val); break; - case DataExpression.FORMAT_TYPE: setFormatTypeString((String) val); break; - case DataExpression.DATATYPEPARAM: setDataType(Types.DataType.valueOf(((String) val).toUpperCase())); break; - case DataExpression.VALUETYPEPARAM: setValueType(Types.ValueType.fromExternalString((String) val)); break; - case DataExpression.DELIM_DELIMITER: setDelim(val.toString()); break; - case DataExpression.SCHEMAPARAM: setSchema(val.toString()); break; - case DataExpression.PRIVACY: setPrivacyConstraints((String) val); break; + case DataExpression.READROWPARAM: + _dim1 = val instanceof Long ? (Long) val : (Integer) val; + break; + case DataExpression.READCOLPARAM: + _dim2 = val instanceof Long ? (Long) val : (Integer) val; + break; + case DataExpression.ROWBLOCKCOUNTPARAM: + setBlocksize((Integer) val); + break; + case DataExpression.READNNZPARAM: + setNnz(val instanceof Long ? (Long) val : (Integer) val); + break; + case DataExpression.FORMAT_TYPE: + setFormatTypeString((String) val); + break; + case DataExpression.DATATYPEPARAM: + setDataType(Types.DataType.valueOf(((String) val).toUpperCase())); + break; + case DataExpression.VALUETYPEPARAM: + setValueType(Types.ValueType.fromExternalString((String) val)); + break; + case DataExpression.DELIM_DELIMITER: + setDelim(val.toString()); + break; + case DataExpression.SCHEMAPARAM: + setSchema(val.toString()); + break; + case DataExpression.PRIVACY: + setPrivacyConstraints((String) val); + break; case DataExpression.DELIM_HAS_HEADER_ROW: - if(val instanceof Boolean){ + if(val instanceof Boolean) { boolean valB = (Boolean) val; setHasHeader(valB); break; @@ -179,7 +196,9 @@ private void parseMetaDataParam(Object key, Object val) else setHasHeader(false); break; - case DataExpression.DELIM_SPARSE: setSparseDelim((boolean) val); break; + case DataExpression.DELIM_SPARSE: + setSparseDelim((boolean) val); + break; } } @@ -238,70 +257,72 @@ public void setDelim(String delim) { } public void setFormatTypeString(String format) { - _formatTypeString = _formatTypeString != null && format == null && _metaObj != null ? (String)JSONHelper.get(_metaObj, DataExpression.FORMAT_TYPE) : format ; + _formatTypeString = _formatTypeString != null && format == null && + _metaObj != null ? (String) JSONHelper.get(_metaObj, DataExpression.FORMAT_TYPE) : format; if(_formatTypeString != null && EnumUtils.isValidEnum(Types.FileFormat.class, _formatTypeString.toUpperCase())) setFileFormat(Types.FileFormat.safeValueOf(_formatTypeString)); } public void setPrivacyConstraints(String privacyConstraints) { - if (privacyConstraints != null && - !privacyConstraints.equals("private") && - !privacyConstraints.equals("private-aggregate") && - !privacyConstraints.equals("public")) { + if(privacyConstraints != null && !privacyConstraints.equals("private") && + !privacyConstraints.equals("private-aggregate") && !privacyConstraints.equals("public")) { throw new DMLRuntimeException("Invalid privacy constraint: " + privacyConstraints + ". 
Must be 'private', 'private-aggregate', or 'public'."); } _privacyConstraints = privacyConstraints; } - + public DataCharacteristics getDataCharacteristics() { return new MatrixCharacteristics(getDim1(), getDim2(), getBlocksize(), getNnz()); } @SuppressWarnings("unchecked") - public HashMap parseMetaDataFileParameters(String mtdFileName, boolean conditional, HashMap varParams) - { - for( Object obj : _metaObj.entrySet() ){ - Map.Entry e = (Map.Entry) obj; + public HashMap parseMetaDataFileParameters(String mtdFileName, boolean conditional, + HashMap varParams) { + for(Object obj : _metaObj.entrySet()) { + Map.Entry e = (Map.Entry) obj; Object key = e.getKey(); Object val = e.getValue(); boolean isValidName = DataExpression.READ_VALID_MTD_PARAM_NAMES.contains(key); - if (!isValidName){ //wrong parameters always rejected + if(!isValidName) { // wrong parameters always rejected raiseValidateError("MTD file " + mtdFileName + " contains invalid parameter name: " + key, false); } parseMetaDataParam(key, val); // if the read method parameter is a constant, then verify value matches MTD metadata file - if (varParams.get(key.toString()) != null && (varParams.get(key.toString()) instanceof ConstIdentifier) - && !varParams.get(key.toString()).toString().equalsIgnoreCase(val.toString())) { + if(varParams.get(key.toString()) != null && (varParams.get(key.toString()) instanceof ConstIdentifier) && + !varParams.get(key.toString()).toString().equalsIgnoreCase(val.toString())) { raiseValidateError("Parameter '" + key.toString() - + "' has conflicting values in metadata and read statement. MTD file value: '" - + val.toString() + "'. Read statement value: '" + varParams.get(key.toString()) + "'.", conditional); - } else { - // if the read method does not specify parameter value, then add MTD metadata file value to parameter list - if (varParams.get(key.toString()) == null){ - if (( !key.toString().equalsIgnoreCase(DataExpression.DESCRIPTIONPARAM) ) && - ( !key.toString().equalsIgnoreCase(DataExpression.AUTHORPARAM) ) && - ( !key.toString().equalsIgnoreCase(DataExpression.CREATEDPARAM) ) ) - { + + "' has conflicting values in metadata and read statement. MTD file value: '" + val.toString() + + "'. 
Read statement value: '" + varParams.get(key.toString()) + "'.", conditional); + } + else { + // if the read method does not specify parameter value, then add MTD metadata file value to parameter + // list + if(varParams.get(key.toString()) == null) { + if((!key.toString().equalsIgnoreCase(DataExpression.DESCRIPTIONPARAM)) && + (!key.toString().equalsIgnoreCase(DataExpression.AUTHORPARAM)) && + (!key.toString().equalsIgnoreCase(DataExpression.CREATEDPARAM))) { StringIdentifier strId = new StringIdentifier(val.toString(), this); - if ( key.toString().equalsIgnoreCase(DataExpression.DELIM_HAS_HEADER_ROW) - || key.toString().equalsIgnoreCase(DataExpression.DELIM_FILL) - || key.toString().equalsIgnoreCase(DataExpression.DELIM_SPARSE) - ) { + if(key.toString().equalsIgnoreCase(DataExpression.DELIM_HAS_HEADER_ROW) || + key.toString().equalsIgnoreCase(DataExpression.DELIM_FILL) || + key.toString().equalsIgnoreCase(DataExpression.DELIM_SPARSE)) { // parse these parameters as boolean values BooleanIdentifier boolId = null; - if (strId.toString().equalsIgnoreCase("true")) { + if(strId.toString().equalsIgnoreCase("true")) { boolId = new BooleanIdentifier(true, this); - } else if (strId.toString().equalsIgnoreCase("false")) { + } + else if(strId.toString().equalsIgnoreCase("false")) { boolId = new BooleanIdentifier(false, this); - } else { - raiseValidateError("Invalid value provided for '" + DataExpression.DELIM_HAS_HEADER_ROW + "' in metadata file '" + mtdFileName + "'. " - + "Must be either TRUE or FALSE.", conditional); + } + else { + raiseValidateError("Invalid value provided for '" + DataExpression.DELIM_HAS_HEADER_ROW + + "' in metadata file '" + mtdFileName + "'. " + "Must be either TRUE or FALSE.", + conditional); } varParams.remove(key.toString()); addVarParam(key.toString(), boolId, varParams); @@ -313,37 +334,37 @@ public HashMap parseMetaDataFileParameters(String mtdFileNam } } - else if ( key.toString().equalsIgnoreCase(DataExpression.DELIM_FILL_VALUE)) { + else if(key.toString().equalsIgnoreCase(DataExpression.DELIM_FILL_VALUE)) { // parse these parameters as numeric values DoubleIdentifier doubleId = new DoubleIdentifier(Double.parseDouble(strId.toString()), this); varParams.remove(key.toString()); addVarParam(key.toString(), doubleId, varParams); } - else if (key.toString().equalsIgnoreCase(DataExpression.DELIM_NA_STRINGS) - || key.toString().equalsIgnoreCase(DataExpression.PRIVACY) - || key.toString().equalsIgnoreCase(DataExpression.FINE_GRAINED_PRIVACY)) { + else if(key.toString().equalsIgnoreCase(DataExpression.DELIM_NA_STRINGS) || + key.toString().equalsIgnoreCase(DataExpression.PRIVACY) || + key.toString().equalsIgnoreCase(DataExpression.FINE_GRAINED_PRIVACY)) { String naStrings = null; - if ( val instanceof String) { + if(val instanceof String) { naStrings = val.toString(); } - else if (val instanceof JSONArray) { + else if(val instanceof JSONArray) { StringBuilder sb = new StringBuilder(); - JSONArray valarr = (JSONArray)val; - for(int naid=0; naid < valarr.size(); naid++ ) { - sb.append( (String) valarr.get(naid) ); - if ( naid < valarr.size()-1) - sb.append( DataExpression.DELIM_NA_STRING_SEP ); + JSONArray valarr = (JSONArray) val; + for(int naid = 0; naid < valarr.size(); naid++) { + sb.append((String) valarr.get(naid)); + if(naid < valarr.size() - 1) + sb.append(DataExpression.DELIM_NA_STRING_SEP); } naStrings = sb.toString(); } - else if ( val instanceof JSONObject ){ - JSONObject valJsonObject = (JSONObject)val; + else if(val instanceof JSONObject) { + JSONObject 
valJsonObject = (JSONObject) val; naStrings = valJsonObject.toString(); } else { - throw new ParseException("Type of value " + val - + " from metadata not recognized by parser."); + throw new ParseException( + "Type of value " + val + " from metadata not recognized by parser."); } StringIdentifier sid = new StringIdentifier(naStrings, this); varParams.remove(key.toString()); @@ -364,24 +385,29 @@ else if ( val instanceof JSONObject ){ } public void addVarParam(String name, Expression value, HashMap varParams) { - if (DMLScript.VALIDATOR_IGNORE_ISSUES && (value == null)) { + if(DMLScript.VALIDATOR_IGNORE_ISSUES && (value == null)) { return; } varParams.put(name, value); // if required, initialize values setFilename(value.getFilename()); - if (getBeginLine() == 0) setBeginLine(value.getBeginLine()); - if (getBeginColumn() == 0) setBeginColumn(value.getBeginColumn()); - if (getEndLine() == 0) setEndLine(value.getEndLine()); - if (getEndColumn() == 0) setEndColumn(value.getEndColumn()); - if (getText() == null) setText(value.getText()); + if(getBeginLine() == 0) + setBeginLine(value.getBeginLine()); + if(getBeginColumn() == 0) + setBeginColumn(value.getBeginColumn()); + if(getEndLine() == 0) + setEndLine(value.getEndLine()); + if(getEndColumn() == 0) + setEndColumn(value.getEndColumn()); + if(getText() == null) + setText(value.getText()); } public static String checkHasDelimitedFormat(String filename, boolean conditional) { // if the MTD file exists, check the format is not binary MetaDataAll mtdObject = new MetaDataAll(filename + ".mtd", conditional, false); - if (mtdObject.mtdExists()) { + if(mtdObject.mtdExists()) { try { mtdObject.setFormatTypeString((String) mtdObject._metaObj.get(DataExpression.FORMAT_TYPE)); if(Types.FileFormat.isDelimitedFormat(mtdObject.getFormatTypeString())) @@ -394,24 +420,20 @@ public static String checkHasDelimitedFormat(String filename, boolean conditiona return null; } - public static boolean checkHasMatrixMarketFormat(String inputFileName, String mtdFileName, boolean conditional) - { + public static boolean checkHasMatrixMarketFormat(String inputFileName, String mtdFileName, boolean conditional) { // Check the MTD file exists. if there is an MTD file, return false. 
MetaDataAll mtdObject = new MetaDataAll(mtdFileName, conditional, false); - if (mtdObject.mtdExists()) + if(mtdObject.mtdExists()) return false; - if( HDFSTool.existsFileOnHDFS(inputFileName) - && !HDFSTool.isDirectory(inputFileName) ) - { + if(HDFSTool.existsFileOnHDFS(inputFileName) && !HDFSTool.isDirectory(inputFileName)) { Path path = new Path(inputFileName); - try( BufferedReader in = new BufferedReader(new InputStreamReader( - IOUtilFunctions.getFileSystem(path).open(path)))) - { + try(BufferedReader in = new BufferedReader( + new InputStreamReader(IOUtilFunctions.getFileSystem(path).open(path)))) { String headerLine = new String(""); - if (in.ready()) + if(in.ready()) headerLine = in.readLine(); - return (headerLine !=null && headerLine.startsWith("%%")); + return(headerLine != null && headerLine.startsWith("%%")); } catch(Exception ex) { throw new LanguageException("Failed to read matrix market header.", ex); @@ -420,6 +442,13 @@ public static boolean checkHasMatrixMarketFormat(String inputFileName, String mt return false; } + public static String checkHasHDF5Format(String filename) { + if(filename != null && filename.toLowerCase().endsWith(".h5")) { + return FileFormat.HDF5.toString(); + } + return null; + } + @Override public String toString() { return "MetaDataAll\n" + _metaObj + "\n" + super.toString(); diff --git a/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test.java b/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test.java index c20294cd85b..bd6af2fe066 100644 --- a/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test.java +++ b/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test.java @@ -19,7 +19,18 @@ package org.apache.sysds.test.functions.io.hdf5; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.io.File; +import org.apache.commons.io.FileUtils; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecMode; @@ -27,30 +38,52 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Test; -public abstract class ReadHDF5Test extends ReadHDF5TestBase { +public class ReadHDF5Test extends ReadHDF5TestBase { - protected abstract int getId(); + private static final double eps = 1e-9; + private static final String TEST_NAME = "ReadHDF5Test"; - protected String getInputHDF5FileName() { - return "transfusion_" + getId() + ".h5"; + private static final List TEST_CASES = Collections.unmodifiableList( + Arrays.asList(new Hdf5TestCase("test_single_dataset.h5", "data", DmlVariant.FORMAT_AND_DATASET), + new Hdf5TestCase("test_multiple_datasets.h5", "matrix_2d", DmlVariant.DATASET_ONLY), + new Hdf5TestCase("test_multiple_datasets.h5", "matrix_3d", DmlVariant.DATASET_ONLY), + new Hdf5TestCase("test_multi_tensor_samples.h5", "label", DmlVariant.DATASET_ONLY), + new Hdf5TestCase("test_multi_tensor_samples.h5", "sen1", DmlVariant.DATASET_ONLY), + new Hdf5TestCase("test_nested_groups.h5", "group1/subgroup/data2", DmlVariant.FORMAT_AND_DATASET))); + + @Override + protected String getTestName() { + return TEST_NAME; } - private final static double eps = 1e-9; + @Override + protected String 
getTestClassDir() { + return TEST_CLASS_DIR; + } - @Test - public void testHDF51_Seq_CP() { - runReadHDF5Test(getId(), ExecMode.SINGLE_NODE, false); + @BeforeClass + public static void setUpClass() { + Path scriptDir = Paths.get(SCRIPT_DIR + TEST_DIR); + generateHdf5Data(scriptDir); } @Test - public void testHDF51_Parallel_CP() { - runReadHDF5Test(getId(), ExecMode.SINGLE_NODE, true); + public void testReadSequential() { + for(Hdf5TestCase tc : TEST_CASES) + runReadHDF5Test(tc, ExecMode.SINGLE_NODE, false); } - protected void runReadHDF5Test(int testNumber, ExecMode platform, boolean parallel) { + @Test + public void testReadSequentialParallelIO() { + for(Hdf5TestCase tc : TEST_CASES) + runReadHDF5Test(tc, ExecMode.SINGLE_NODE, true); + } + protected void runReadHDF5Test(Hdf5TestCase testCase, ExecMode platform, boolean parallel) { ExecMode oldPlatform = rtplatform; rtplatform = platform; @@ -61,21 +94,28 @@ protected void runReadHDF5Test(int testNumber, ExecMode platform, boolean parall boolean oldpar = CompilerConfig.FLAG_PARREADWRITE_TEXT; try { - CompilerConfig.FLAG_PARREADWRITE_TEXT = parallel; TestConfiguration config = getTestConfiguration(getTestName()); loadTestConfiguration(config); String HOME = SCRIPT_DIR + TEST_DIR; - String inputMatrixName = HOME + INPUT_DIR + getInputHDF5FileName(); // always read the same data - String datasetName = "DATASET_1"; + String inputMatrixName = HOME + INPUT_DIR + testCase.hdf5File; + + fullDMLScriptName = HOME + testCase.variant.getScriptName(); + programArgs = new String[] {"-args", inputMatrixName, testCase.dataset, output("Y")}; - fullDMLScriptName = HOME + getTestName() + "_" + testNumber + ".dml"; - programArgs = new String[] {"-args", inputMatrixName, datasetName, output("Y")}; + // Clean per-case output/expected to avoid reusing stale metadata between looped cases + String outY = output("Y"); + String expY = expected("Y"); + FileUtils.deleteQuietly(new File(outY)); + FileUtils.deleteQuietly(new File(outY + ".mtd")); + FileUtils.deleteQuietly(new File(expY)); + FileUtils.deleteQuietly(new File(expY + ".mtd")); fullRScriptName = HOME + "ReadHDF5_Verify.R"; - rCmd = "Rscript" + " " + fullRScriptName + " " + inputMatrixName + " " + datasetName + " " + expectedDir(); + rCmd = "Rscript" + " " + fullRScriptName + " " + inputMatrixName + " " + testCase.dataset + " " + + expectedDir(); runTest(true, false, null, -1); runRScript(true); @@ -90,4 +130,61 @@ protected void runReadHDF5Test(int testNumber, ExecMode platform, boolean parall DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } + + private static void generateHdf5Data(Path scriptDir) { + ProcessBuilder processBuilder = new ProcessBuilder("Rscript", "gen_HDF5_testdata.R"); + processBuilder.directory(scriptDir.toFile()); + processBuilder.redirectErrorStream(true); + + try { + Process process = processBuilder.start(); + StringBuilder output = new StringBuilder(); + try(BufferedReader reader = new BufferedReader( + new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) { + reader.lines().forEach(line -> output.append(line).append(System.lineSeparator())); + } + int exitCode = process.waitFor(); + if(exitCode != 0) + Assert.fail("Failed to execute gen_HDF5_testdata.R (exit " + exitCode + "):\n" + output); + } + catch(IOException e) { + Assert.fail("Unable to execute gen_HDF5_testdata.R: " + e.getMessage()); + } + catch(InterruptedException e) { + Thread.currentThread().interrupt(); + Assert.fail("Interrupted while generating HDF5 test data."); + } + } + + private enum 
DmlVariant { + FORMAT_AND_DATASET("ReadHDF5_WithFormatAndDataset.dml"), DATASET_ONLY("ReadHDF5_WithDataset.dml"), + DEFAULT("ReadHDF5_Default.dml"); + + private final String scriptName; + + DmlVariant(String scriptName) { + this.scriptName = scriptName; + } + + public String getScriptName() { + return scriptName; + } + } + + private static final class Hdf5TestCase { + private final String hdf5File; + private final String dataset; + private final DmlVariant variant; + + private Hdf5TestCase(String hdf5File, String dataset, DmlVariant variant) { + this.hdf5File = hdf5File; + this.dataset = dataset; + this.variant = variant; + } + + @Override + public String toString() { + return hdf5File + "::" + dataset; + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test2.java b/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test2.java deleted file mode 100644 index d6a4c763c34..00000000000 --- a/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test2.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.test.functions.io.hdf5; - -public class ReadHDF5Test2 extends ReadHDF5Test { - - private final static String TEST_NAME = "ReadHDF5Test"; - private final static String TEST_CLASS_DIR = TEST_DIR + ReadHDF5Test2.class.getSimpleName() + "/"; - - protected String getTestName() { - return TEST_NAME; - } - - protected String getTestClassDir() { - return TEST_CLASS_DIR; - } - - protected int getId() { - return 2; - } -} diff --git a/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test3.java b/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test3.java deleted file mode 100644 index 71a6b1762ec..00000000000 --- a/src/test/java/org/apache/sysds/test/functions/io/hdf5/ReadHDF5Test3.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysds.test.functions.io.hdf5; - -public class ReadHDF5Test3 extends ReadHDF5Test { - - private final static String TEST_NAME = "ReadHDF5Test"; - private final static String TEST_CLASS_DIR = TEST_DIR + ReadHDF5Test3.class.getSimpleName() + "/"; - - protected String getTestName() { - return TEST_NAME; - } - - protected String getTestClassDir() { - return TEST_CLASS_DIR; - } - - protected int getId() { - return 3; - } -} diff --git a/src/test/scripts/functions/io/hdf5/ReadHDF5Test_3.dml b/src/test/scripts/functions/io/hdf5/ReadHDF5_Default.dml similarity index 100% rename from src/test/scripts/functions/io/hdf5/ReadHDF5Test_3.dml rename to src/test/scripts/functions/io/hdf5/ReadHDF5_Default.dml diff --git a/src/test/scripts/functions/io/hdf5/ReadHDF5_Verify.R b/src/test/scripts/functions/io/hdf5/ReadHDF5_Verify.R index 2b977007dd2..925e092f724 100644 --- a/src/test/scripts/functions/io/hdf5/ReadHDF5_Verify.R +++ b/src/test/scripts/functions/io/hdf5/ReadHDF5_Verify.R @@ -26,5 +26,19 @@ options(digits=22) library("rhdf5") -Y = h5read(args[1],args[2],native = TRUE) -writeMM(as(Y, "CsparseMatrix"), paste(args[3], "Y", sep="")) +Y = h5read(args[1], args[2], native = TRUE) +dims = dim(Y) + +if(length(dims) == 1) { + # convert to a column matrix + Y_mat = matrix(Y, ncol = 1) +} else if(length(dims) > 2) { + # flatten everything beyond the first dimension into columns + perm = c(1, rev(seq(2, length(dims)))) + Y_mat = matrix(aperm(Y, perm), nrow = dims[1], ncol = prod(dims[-1])) +} else { + # for 2d , systemds treats it the same + Y_mat = Y +} + +writeMM(as(Y_mat, "CsparseMatrix"), paste(args[3], "Y", sep="")) diff --git a/src/test/scripts/functions/io/hdf5/ReadHDF5Test_2.dml b/src/test/scripts/functions/io/hdf5/ReadHDF5_WithDataset.dml similarity index 100% rename from src/test/scripts/functions/io/hdf5/ReadHDF5Test_2.dml rename to src/test/scripts/functions/io/hdf5/ReadHDF5_WithDataset.dml diff --git a/src/test/scripts/functions/io/hdf5/ReadHDF5Test_1.dml b/src/test/scripts/functions/io/hdf5/ReadHDF5_WithFormatAndDataset.dml similarity index 100% rename from src/test/scripts/functions/io/hdf5/ReadHDF5Test_1.dml rename to src/test/scripts/functions/io/hdf5/ReadHDF5_WithFormatAndDataset.dml diff --git a/src/test/scripts/functions/io/hdf5/gen_HDF5_testdata.R b/src/test/scripts/functions/io/hdf5/gen_HDF5_testdata.R new file mode 100644 index 00000000000..fb9fed140ab --- /dev/null +++ b/src/test/scripts/functions/io/hdf5/gen_HDF5_testdata.R @@ -0,0 +1,247 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +# Generate various HDF5 test files with different formats. +# Creates test files in the 'in' directory. 
+ +if (!require("rhdf5", quietly = TRUE)) { + cat("Error: rhdf5 is not installed.\n") + quit(status = 1) +} + +SMALL_MATRIX_2D <- c(200, 40) +SMALL_MATRIX_3D <- c(15, 15, 5) +SMALL_TENSOR_4D_A <- c(120, 16, 16, 4) +SMALL_TENSOR_4D_B <- c(120, 16, 16, 5) +SMALL_LABEL_MATRIX <- c(120, 12) + +VECTOR_LENGTH <- 200 +STRING_ARRAY_LENGTH <- 30 + +CHUNK_SHAPE <- c(100, 20) + +write_matrix <- function(file_path, dataset_name, shape, generator = function(n) rnorm(n), storage.mode = "double", H5type = NULL) { + values <- generator(prod(shape)) + h5createDataset( + file_path, + dataset_name, + dims = rev(shape), + chunk = NULL, + filter = "NONE", # contiguous, uncompressed layout + level = 0, + shuffle = FALSE, + storage.mode = storage.mode, + H5type = H5type, + native = TRUE # use R column-major order, same in h5read(..., native=TRUE) in tests. + ) + h5write(array(values, dim = shape), file_path, dataset_name, native = TRUE) +} + +generate_test_file_single_dataset <- function(dir) { + file_path <- file.path(dir, "test_single_dataset.h5") + h5createFile(file_path) + write_matrix(file_path, "data", SMALL_MATRIX_2D) + cat("Created test_single_dataset.h5 (single 2D dataset)\n") +} + +generate_test_file_multiple_datasets <- function(dir) { + file_path <- file.path(dir, "test_multiple_datasets.h5") + h5createFile(file_path) + write_matrix(file_path, "matrix_2d", SMALL_MATRIX_2D) + # Create 1D vector without compression/filters + h5createDataset(file_path, "vector_1d", dims = VECTOR_LENGTH, chunk = NULL, filter = "NONE", level = 0, shuffle = FALSE) + h5write(rnorm(VECTOR_LENGTH), file_path, "vector_1d", native = TRUE) + write_matrix(file_path, "matrix_3d", SMALL_MATRIX_3D) + cat("Created test_multiple_datasets.h5 (1D/2D/3D datasets)\n") +} + +generate_test_file_different_dtypes <- function(dir) { + file_path <- file.path(dir, "test_different_dtypes.h5") + h5createFile(file_path) + # H5T_IEEE_F64LE (64-bit float) + write_matrix(file_path, "double_primary", SMALL_MATRIX_2D, storage.mode = "double") + # H5T_IEEE_F32LE (32-bit float) + write_matrix(file_path, "float32", SMALL_MATRIX_2D, H5type = "H5T_IEEE_F32LE") + # H5T_STD_I32LE (32-bit integer) + write_matrix( + file_path, + "int32", + SMALL_MATRIX_2D, + generator = function(n) as.integer(sample(-100:100, n, replace = TRUE)), + storage.mode = "integer" + ) + # H5T_STD_I64LE (64-bit integer) + write_matrix( + file_path, + "int64", + SMALL_MATRIX_2D, + generator = function(n) as.integer(sample(-100:100, n, replace = TRUE)), + H5type = "H5T_STD_I64LE" + ) + cat("Created test_different_dtypes.h5 (double/float/int32/int64 datasets)\n") +} + +# https://support.hdfgroup.org/documentation/hdf5-docs/advanced_topics/chunking_in_hdf5.html +generate_test_file_chunked <- function(dir) { + file_path <- file.path(dir, "test_chunked.h5") + h5createFile(file_path) + + data <- array(rnorm(prod(SMALL_MATRIX_2D)), dim = SMALL_MATRIX_2D) + + h5createDataset(file_path, "chunked_data", dims = SMALL_MATRIX_2D, chunk = CHUNK_SHAPE, + filter = "NONE", level = 0, shuffle = FALSE) + h5write(data, file_path, "chunked_data", native = TRUE) + + write_matrix(file_path, "non_chunked_data", SMALL_MATRIX_2D) + cat("Created test_chunked.h5 (chunked dataset)\n") +} + +generate_test_file_compressed <- function(dir) { + file_path <- file.path(dir, "test_compressed.h5") + h5createFile(file_path) + data <- array(rnorm(prod(SMALL_MATRIX_2D)), dim = SMALL_MATRIX_2D) + h5createDataset(file_path, "gzip_compressed_9", dims = SMALL_MATRIX_2D, + chunk = SMALL_MATRIX_2D, level = 9) + h5write(data, 
file_path, "gzip_compressed_9", native = TRUE) + h5createDataset(file_path, "gzip_compressed_1", dims = SMALL_MATRIX_2D, + chunk = SMALL_MATRIX_2D, level = 1) + h5write(data, file_path, "gzip_compressed_1", native = TRUE) + cat("Created test_compressed.h5 (gzip compression)\n") +} + +generate_test_file_multi_tensor_samples <- function(dir) { + file_path <- file.path(dir, "test_multi_tensor_samples.h5") + h5createFile(file_path) + write_matrix( + file_path, + "sen1", + SMALL_TENSOR_4D_A + ) + write_matrix( + file_path, + "sen2", + SMALL_TENSOR_4D_B + ) + write_matrix( + file_path, + "label", + SMALL_LABEL_MATRIX, + generator = function(n) as.integer(sample(0:1, n, replace = TRUE)) + ) + cat("Created test_multi_tensor_samples.h5 (multi-input tensors)\n") +} + +generate_test_file_nested_groups <- function(dir) { + file_path <- file.path(dir, "test_nested_groups.h5") + h5createFile(file_path) + write_matrix(file_path, "root_data", SMALL_MATRIX_2D) + h5createGroup(file_path, "group1") + write_matrix(file_path, "group1/data1", SMALL_MATRIX_2D) + h5createGroup(file_path, "group1/subgroup") + write_matrix(file_path, "group1/subgroup/data2", SMALL_MATRIX_2D) + cat("Created test_nested_groups.h5 (nested group hierarchy)\n") +} + +generate_test_file_with_attributes <- function(dir) { + file_path <- file.path(dir, "test_with_attributes.h5") + h5createFile(file_path) + write_matrix(file_path, "data", SMALL_MATRIX_2D) + + fid <- H5Fopen(file_path) + did <- H5Dopen(fid, "data") + h5writeAttribute("Test dataset with attributes", did, "description") + h5writeAttribute(1.0, did, "version") + h5writeAttribute(SMALL_MATRIX_2D, did, "shape") + H5Dclose(did) + + h5writeAttribute("2025-11-26", fid, "file_created") + h5writeAttribute("attributes", fid, "test_type") + H5Fclose(fid) + cat("Created test_with_attributes.h5 (dataset + file attributes)\n") +} + +generate_test_file_empty_datasets <- function(dir) { + file_path <- file.path(dir, "test_empty_datasets.h5") + h5createFile(file_path) + h5createDataset(file_path, "empty", dims = c(0, SMALL_MATRIX_2D[2]), + filter = "NONE", level = 0, shuffle = FALSE) + + h5createDataset(file_path, "scalar", dims = 1, + filter = "NONE", level = 0, shuffle = FALSE, chunk = 1) + h5write(1.0, file_path, "scalar", native = TRUE) + h5createDataset(file_path, "vector", dims = VECTOR_LENGTH, + filter = "NONE", level = 0, shuffle = FALSE, chunk = VECTOR_LENGTH) + h5write(rnorm(VECTOR_LENGTH), file_path, "vector", native = TRUE) + cat("Created test_empty_datasets.h5 (empty/scalar/vector)\n") +} + +generate_test_file_string_datasets <- function(dir) { + file_path <- file.path(dir, "test_string_datasets.h5") + h5createFile(file_path) + strings <- paste0("string_", 0:(STRING_ARRAY_LENGTH - 1)) + # Create string dataset without compression/filters + h5createDataset(file_path, "string_array", dims = STRING_ARRAY_LENGTH, + storage.mode = "character", filter = "NONE", level = 0, + shuffle = FALSE, chunk = STRING_ARRAY_LENGTH) + h5write(strings, file_path, "string_array", native = TRUE) + cat("Created test_string_datasets.h5 (string datasets)\n") +} + +main <- function() { + if (basename(getwd()) != "hdf5") { + cat("You must execute this script from the 'hdf5' directory\n") + quit(status = 1) + } + + testdir <- "in" + if (!dir.exists(testdir)) { + dir.create(testdir) + } + + test_functions <- list( + generate_test_file_single_dataset, + generate_test_file_multiple_datasets, + generate_test_file_different_dtypes, + generate_test_file_chunked, + generate_test_file_compressed, + 
generate_test_file_multi_tensor_samples, + generate_test_file_nested_groups, + generate_test_file_with_attributes, + generate_test_file_empty_datasets, + generate_test_file_string_datasets + ) + + for (test_func in test_functions) { + tryCatch({ + test_func(testdir) + }, error = function(e) { + cat(sprintf(" ✗ Error: %s\n", conditionMessage(e))) + }) + } + + files <- sort(list.files(testdir, pattern = "\\.h5$", full.names = TRUE)) + cat(sprintf("\nGenerated %d HDF5 test files in %s\n", length(files), normalizePath(testdir))) +} + +if (!interactive()) { + main() +}
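
Note on the N-dimensional flattening convention used by ReadHDF5_Verify.R above: for datasets with more than two dimensions, the script keeps the first dimension as rows and flattens all remaining dimensions into columns (rows = dims[1], cols = prod(dims[-1]), via aperm with the permutation c(1, rev(seq(2, length(dims))))), while 1-D datasets become single-column matrices. As a worked example, the sen1 tensor written by gen_HDF5_testdata.R has shape 120 x 16 x 16 x 4 (SMALL_TENSOR_4D_A), so the verification baseline is a 120 x 1024 matrix; whether the reader flattens columns in exactly the same order is left to the DML/R comparison in the test and is not restated here.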
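
The new MetaDataAll.checkHasHDF5Format helper is purely extension-based: it returns FileFormat.HDF5.toString() for paths ending in ".h5" and null otherwise, mirroring the existing checkHasDelimitedFormat and checkHasMatrixMarketFormat probes in the same class. Below is a minimal sketch of how a caller might use it to default a read format; the Hdf5FormatProbe class name, the sample path, and the CSV fallback are illustrative assumptions, while the MetaDataAll and Types.FileFormat APIs come from this patch.

import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.runtime.meta.MetaDataAll;

public class Hdf5FormatProbe {
	public static void main(String[] args) {
		// Sample path (assumption): one of the fixtures produced by gen_HDF5_testdata.R
		String input = args.length > 0 ? args[0] : "in/test_single_dataset.h5";
		// Returns FileFormat.HDF5.toString() for *.h5 inputs, null for anything else
		String fmt = MetaDataAll.checkHasHDF5Format(input);
		if(fmt == null)
			fmt = FileFormat.CSV.toString(); // hypothetical fallback, only for this sketch
		System.out.println("Resolved format for " + input + ": " + fmt);
	}
}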