11package de .dmi3y .behaiv .kernel ;
22
3- import com .google . gson . Gson ;
3+ import com .fasterxml . jackson . core . type . TypeReference ;
44import de .dmi3y .behaiv .kernel .logistic .LogisticUtils ;
55import de .dmi3y .behaiv .storage .BehaivStorage ;
6+ import de .dmi3y .behaiv .tools .Pair ;
67import org .apache .commons .lang3 .ArrayUtils ;
7- import org .apache .commons .math3 .util .Pair ;
88import org .ejml .simple .SimpleMatrix ;
99
1010import java .io .BufferedReader ;
@@ -41,19 +41,19 @@ public boolean isEmpty() {
4141 @ Override
4242 public void fit (ArrayList <Pair <ArrayList <Double >, String >> data ) {
4343 this .data = data ;
44- labels = this .data .stream ().map (Pair ::getSecond ).distinct ().collect (Collectors .toList ());
44+ labels = this .data .stream ().map (Pair ::getValue ).distinct ().collect (Collectors .toList ());
4545 if (readyToPredict ()) {
4646
4747
4848 //features
49- double [][] inputs = this .data .stream ().map (Pair ::getFirst ).map (l -> l .toArray (new Double [0 ]))
49+ double [][] inputs = this .data .stream ().map (Pair ::getKey ).map (l -> l .toArray (new Double [0 ]))
5050 .map (ArrayUtils ::toPrimitive )
5151 .toArray (double [][]::new );
5252
5353 //labels
5454 double [][] labelArray = new double [data .size ()][labels .size ()];
5555 for (int i = 0 ; i < data .size (); i ++) {
56- int dummyPos = labels .indexOf (data .get (i ).getSecond ());
56+ int dummyPos = labels .indexOf (data .get (i ).getValue ());
5757 labelArray [i ][dummyPos ] = 1.0 ;
5858 }
5959
@@ -107,22 +107,21 @@ public String predictOne(ArrayList<Double> features) {
107107
108108 @ Override
109109 public void save (BehaivStorage storage ) throws IOException {
110- if (theta == null && data == null ) {
110+ if (theta == null && ( data == null || data . isEmpty ()) ) {
111111 throw new IOException ("Not enough data to save, network data is empty" );
112112 }
113- if (labels == null ) {
113+ if (labels == null || labels . isEmpty () ) {
114114 String message ;
115115 message = "Kernel collected labels but failed to get data, couldn't save network." ;
116116 throw new IOException (message );
117117 }
118118 if (theta == null ) {
119119 super .save (storage );
120-
121120 } else {
122121 theta .saveToFileBinary (storage .getNetworkFile (id ).toString ());
123122 try (final BufferedWriter writer = new BufferedWriter (new FileWriter (storage .getNetworkMetadataFile (id )))) {
124- final Gson gson = new Gson ();
125- writer .write (gson . toJson (labels ));
123+
124+ writer .write (objectMapper . writeValueAsString (labels ));
126125 } catch (Exception e ) {
127126 e .printStackTrace ();
128127 }
@@ -139,8 +138,14 @@ public void restore(BehaivStorage storage) throws IOException {
139138 }
140139
141140 try (final BufferedReader reader = new BufferedReader (new FileReader (storage .getNetworkMetadataFile (id )))) {
142- final Gson gson = new Gson ();
143- labels = ((List <String >) gson .fromJson (reader .readLine (), labels .getClass ()));
141+ final TypeReference <ArrayList <String >> typeReference = new TypeReference <ArrayList <String >>() {
142+ };
143+ final String labelsData = reader .readLine ();
144+ if (labelsData == null ) {
145+ labels = new ArrayList <>();
146+ } else {
147+ labels = objectMapper .readValue (labelsData , typeReference );
148+ }
144149 } catch (IOException e ) {
145150 e .printStackTrace ();
146151 }
0 commit comments