Skip to content

Commit b12bb8f

Browse files
committed
PR review comments.
Signed-off-by: Robert Altena <[email protected]>
1 parent 56d1a68 commit b12bb8f

File tree

3 files changed

+32
-14
lines changed

3 files changed

+32
-14
lines changed

nd4j-examples/src/main/java/org/nd4j/examples/Nd4jEx16_Serialization.java

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import org.nd4j.linalg.factory.Nd4j;
55

66
import java.io.*;
7+
import java.util.Map;
78
import java.util.Objects;
89

910
/**
@@ -12,47 +13,64 @@
1213
*/
1314
public class Nd4jEx16_Serialization {
1415

15-
public static void main(String[] args) throws IOException {
16+
public static void main(String[] args) throws Exception {
1617
ClassLoader loader = Nd4jEx16_Serialization.class.getClassLoader(); // used to read files from resources.
1718

1819
// 1. binary format from stream.
1920
INDArray arrWrite = Nd4j.linspace(1,25,25).reshape(5,5);
2021
String pathname = "tmp.bin";
21-
DataOutputStream sWrite = new DataOutputStream(new FileOutputStream(new File(pathname )));
22-
Nd4j.write(arrWrite, sWrite);
2322

24-
DataInputStream sRead = new DataInputStream(new FileInputStream(new File(pathname )));
25-
INDArray arrRead = Nd4j.read(sRead);
23+
try(DataOutputStream sWrite = new DataOutputStream(new FileOutputStream(new File(pathname )))){
24+
Nd4j.write(arrWrite, sWrite);
25+
}
26+
27+
INDArray arrRead;
28+
try(DataInputStream sRead = new DataInputStream(new FileInputStream(new File(pathname )))){
29+
arrRead = Nd4j.read(sRead);
30+
}
31+
2632
// We now have our test matrix in arrRead
27-
System.out.println("Read from binary format:" );
33+
System.out.println("Read from binary stream:" );
2834
System.out.println(arrRead );
2935

3036

31-
// 2. Read the numpy npy (and npz) formats:
32-
File file = new File( Objects.requireNonNull(loader.getResource("twentyfive.npy")).getFile());
37+
// 2. Write and read the numpy npy format:
38+
File file = new File("nd4j.npy" );
39+
Nd4j.writeAsNumpy(arrRead, file ); // Try to read this file from Python: y = np.load('nd4j.npy')
40+
41+
arrRead = Nd4j.createFromNpyFile(file); // We can read these files from nd4j.
42+
System.out.println();
43+
System.out.println("Read from Numpy .npy format:" );
44+
System.out.println(arrRead);
45+
46+
47+
// 3. Read the numpy npz format:
48+
file = new File( Objects.requireNonNull(loader.getResource("numpyz.npz")).getFile());
3349

34-
INDArray x = Nd4j.createFromNpyFile(file); // Nd4j.createFromNpzFile for npz Numpy files.
50+
Map<String, INDArray> arrayMap = Nd4j.createFromNpzFile(file); //We get a map reading an .npz file.
3551
System.out.println();
36-
System.out.println("Read from Numpy .npyformat:" );
37-
System.out.println(x);
52+
System.out.println("Read from Numpy .npz format:" );
53+
System.out.println(arrayMap.get("arr_0")); //We know there are 2 arrays in the .npz file.
54+
System.out.println(arrayMap.get("arr_1"));
3855

3956

40-
// 3. binary format from file.
57+
// 4. binary format from file.
4158
file = new File(pathname);
4259
Nd4j.saveBinary(arrWrite, file);
4360
arrRead = Nd4j.readBinary(file );
4461
System.out.println();
45-
System.out.println("Read from Numpy .npyformat:" );
62+
System.out.println("Read from binary format:" );
4663
System.out.println(arrRead);
4764

4865

49-
// 4. read a csv file.
66+
// 5. read a csv file.
5067
file = new File( Objects.requireNonNull(loader.getResource("twentyfive.csv")).getFile());
5168
String Filename = file.getAbsolutePath();
5269
arrRead = Nd4j.readNumpy(Filename, ",");
5370
System.out.println();
5471
System.out.println("Read from csv format:" );
5572
System.out.println(arrRead);
73+
5674
}
5775

5876
}
794 Bytes
Binary file not shown.
-328 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)